diff --git a/README.md b/README.md
index 88347476..00f7c812 100644
--- a/README.md
+++ b/README.md
@@ -1,92 +1,138 @@
-# `dm_control`: The DeepMind Control Suite and Package
+# `dm_control`: Google DeepMind Infrastructure for Physics-Based Simulation.
-# 
+Google DeepMind's software stack for physics-based simulation and Reinforcement
+Learning environments, using MuJoCo physics.
-This package contains:
+An **introductory tutorial** for this package is available as a Colaboratory
+notebook:
+[Open in Colab](https://colab.research.google.com/github/google-deepmind/dm_control/blob/main/tutorial.ipynb)
-- A set of Python Reinforcement Learning environments powered by the MuJoCo
- physics engine. See the `suite` subdirectory.
+## Overview
-- Libraries that provide Python bindings to the MuJoCo physics engine.
+This package consists of the following "core" components:
-If you use this package, please cite our accompanying accompanying [tech report](https://arxiv.org/abs/1801.00690).
+- [`dm_control.mujoco`]: Libraries that provide Python bindings to the MuJoCo
+ physics engine.
-## Installation and requirements
+- [`dm_control.suite`]: A set of Python Reinforcement Learning environments
+ powered by the MuJoCo physics engine.
-Follow these steps to install `dm_control`:
+- [`dm_control.viewer`]: An interactive environment viewer.
-1. Download MuJoCo Pro 1.50 from the download page on the [MuJoCo website](http://www.mujoco.org/).
- MuJoCo Pro must be installed before `dm_control`, since `dm_control`'s
- install script generates Python [`ctypes`](https://docs.python.org/2/library/ctypes.html)
- bindings based on MuJoCo's header files. By default, `dm_control` assumes
- that the MuJoCo Zip archive is extracted as `~/.mujoco/mjpro150`.
+Additionally, the following components are available for the creation of more
+complex control tasks:
-2. Install the `dm_control` Python package by running
- `pip install git+git://github.com/deepmind/dm_control.git`
- (PyPI package coming soon) or by cloning the repository and running
- `pip install /path/to/dm_control/`
- At installation time, `dm_control` looks for the MuJoCo headers from Step 1
- in `~/.mujoco/mjpro150/include`, however this path can be configured with the
- `headers-dir` command line argument.
+- [`dm_control.mjcf`]: A library for composing and modifying MuJoCo MJCF
+ models in Python.
-3. Install a license key for MuJoCo, required by `dm_control` at runtime. See
- the [MuJoCo license key page](https://www.roboti.us/license.html) for further
- details. By default, `dm_control` looks for the MuJoCo license key file at
- `~/.mujoco/mjkey.txt`.
+- `dm_control.composer`: A library for defining rich RL environments from
+ reusable, self-contained components.
-4. If the license key (e.g. `mjkey.txt`) or the shared library provided by
- MuJoCo Pro (e.g. `libmujoco150.so` or `libmujoco150.dylib`) are installed at
- non-default paths, specify their locations using the `MJKEY_PATH` and
- `MJLIB_PATH` environment variables respectively.
+- [`dm_control.locomotion`]: Additional libraries for custom tasks.
-## Additional instructions for Linux
+- [`dm_control.locomotion.soccer`]: Multi-agent soccer tasks.
-Install `GLFW` and `GLEW` through your Linux distribution's package manager.
-For example, on Debian and Ubuntu, this can be done by running
-`sudo apt-get install libglfw3 libglew2.0`.
+If you use this package, please cite our accompanying [publication]:
-## Additional instructions for Homebrew users on macOS
-
-1. The above instructions using `pip` should work, provided that you
- use a Python interpreter that is installed by Homebrew (rather than the
- system-default one).
-
-2. To get OpenGL working, install the `glfw` package from Homebrew by running
- `brew install glfw`.
-
-3. Before running, the `DYLD_LIBRARY_PATH` environment variable needs to be
- updated with the path to the GLFW library. This can be done by running
- `export DYLD_LIBRARY_PATH=$(brew --prefix)/lib:$DYLD_LIBRARY_PATH`.
-
-## Control Suite quickstart
-
-```python
-from dm_control import suite
-import numpy as np
+```
+@article{tunyasuvunakool2020,
+ title = {dm_control: Software and tasks for continuous control},
+ journal = {Software Impacts},
+ volume = {6},
+ pages = {100022},
+ year = {2020},
+ issn = {2665-9638},
+ doi = {https://doi.org/10.1016/j.simpa.2020.100022},
+ url = {https://www.sciencedirect.com/science/article/pii/S2665963820300099},
+ author = {Saran Tunyasuvunakool and Alistair Muldal and Yotam Doron and Salvador Arturo Ortiz Gonzalez and
+ Siqi Liu and Steven Bohez and Josh Merel and Tom Erez and
+ Timothy Lillicrap and Nicolas Heess and Yuval Tassa},
+}
+```
-# Load one task:
-env = suite.load(domain_name="cartpole", task_name="swingup")
+## Installation
-# Iterate over a task set:
-for domain_name, task_name in suite.BENCHMARKING:
- env = suite.load(domain_name, task_name)
+Install `dm_control` from PyPI by running
-# Step through an episode and print out reward, discount and observation.
-action_spec = env.action_spec()
-time_step = env.reset()
-while not time_step.last():
- action = np.random.uniform(action_spec.minimum,
- action_spec.maximum,
- size=action_spec.shape)
- time_step = env.step(action)
- print(time_step.reward, time_step.discount, time_step.observation)
+```sh
+pip install dm_control
```
-See our [tech report](https://arxiv.org/abs/1801.00690) for further details.
+> **Note**: **`dm_control` cannot be installed in "editable" mode** (i.e. `pip
+> install -e`).
+>
+> While `dm_control` has been largely updated to use the pybind11-based bindings
+> provided via the `mujoco` package, at this time it still relies on some legacy
+> components that are automatically generated from MuJoCo header files in a way
+> that is incompatible with editable mode. Attempting to install `dm_control` in
+> editable mode will result in import errors like:
+>
+> ```
+> ImportError: cannot import name 'constants' from partially initialized module 'dm_control.mujoco.wrapper.mjbindings' ...
+> ```
+>
+> The solution is to `pip uninstall dm_control` and then reinstall it without
+> the `-e` flag.
+
+## Versioning
+
+Starting from version 1.0.0, we adopt semantic versioning.
+
+Prior to version 1.0.0, the `dm_control` Python package was versioned `0.0.N`,
+where `N` was an internal revision number that increased by an arbitrary amount
+at every single Git commit.
+
+If you want to install an unreleased version of `dm_control` directly from our
+repository, you can do so by running `pip install
+git+https://github.com/google-deepmind/dm_control.git`.
+
+## Rendering
+
+The MuJoCo Python bindings support three different OpenGL rendering backends:
+EGL (headless, hardware-accelerated), GLFW (windowed, hardware-accelerated), and
+OSMesa (purely software-based). At least one of these three backends must be
+available in order to render through `dm_control`.
+
+* Hardware rendering with a windowing system is supported via GLFW and GLEW.
+  On Linux these can be installed using your distribution's package manager.
+  For example, on Debian and Ubuntu, this can be done by running `sudo apt-get
+  install libglfw3 libglew2.0`. Please note that:
+
+  - [`dm_control.viewer`] can only be used with GLFW.
+  - GLFW will not work on headless machines.
+
+* "Headless" hardware rendering (i.e. without a windowing system such as X11)
+  requires [EXT_platform_device] support in the EGL driver. Recent Nvidia
+  drivers support this. You will also need GLEW. On Debian and Ubuntu, this
+  can be installed via `sudo apt-get install libglew2.0`.
+
+* Software rendering requires GLX and OSMesa. On Debian and Ubuntu these can
+  be installed using `sudo apt-get install libgl1-mesa-glx libosmesa6`.
+
+By default, `dm_control` will attempt to use GLFW first, then EGL, then OSMesa.
+You can also specify a particular backend to use by setting the `MUJOCO_GL=`
+environment variable to `"glfw"`, `"egl"`, or `"osmesa"`, respectively. When
+rendering with EGL, you can also specify which GPU to use for rendering by
+setting the environment variable `MUJOCO_EGL_DEVICE_ID=` to the target GPU ID.
-## Illustration video
-
-Below is a video montage of solved Control Suite tasks, with reward
-visualisation enabled.
+## Additional instructions for Homebrew users on macOS
-[](https://www.youtube.com/watch?v=rAai4QzcYbs)
+1. The above instructions using `pip` should work, provided that you use a
+ Python interpreter that is installed by Homebrew (rather than the
+ system-default one).
+
+2. Before running, the `DYLD_LIBRARY_PATH` environment variable needs to be
+ updated with the path to the GLFW library. This can be done by running
+ `export DYLD_LIBRARY_PATH=$(brew --prefix)/lib:$DYLD_LIBRARY_PATH`.
+
+[EXT_platform_device]: https://www.khronos.org/registry/EGL/extensions/EXT/EGL_EXT_platform_device.txt
+[Releases page on the MuJoCo GitHub repository]: https://github.com/google-deepmind/mujoco/releases
+[MuJoCo website]: https://mujoco.org/
+[publication]: https://doi.org/10.1016/j.simpa.2020.100022
+[`ctypes`]: https://docs.python.org/3/library/ctypes.html
+[`dm_control.mjcf`]: dm_control/mjcf/README.md
+[`dm_control.mujoco`]: dm_control/mujoco/README.md
+[`dm_control.suite`]: dm_control/suite/README.md
+[`dm_control.viewer`]: dm_control/viewer/README.md
+[`dm_control.locomotion`]: dm_control/locomotion/README.md
+[`dm_control.locomotion.soccer`]: dm_control/locomotion/soccer/README.md
diff --git a/dm_control/_render/__init__.py b/dm_control/_render/__init__.py
new file mode 100644
index 00000000..d1dd4446
--- /dev/null
+++ b/dm_control/_render/__init__.py
@@ -0,0 +1,109 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""OpenGL context management for rendering MuJoCo scenes.
+
+By default, the `Renderer` class will try to load one of the following rendering
+APIs, in descending order of priority: GLFW > EGL > OSMesa.
+
+It is also possible to select a specific backend by setting the `MUJOCO_GL=`
+environment variable to 'glfw', 'egl', or 'osmesa'.
+"""
+
+import collections
+import os
+
+from absl import logging
+from dm_control._render import constants
+
+BACKEND = os.environ.get(constants.MUJOCO_GL)
+
+
+# pylint: disable=g-import-not-at-top
+def _import_egl():
+  from dm_control._render.pyopengl.egl_renderer import EGLContext
+  return EGLContext
+
+
+def _import_glfw():
+  from dm_control._render.glfw_renderer import GLFWContext
+  return GLFWContext
+
+
+def _import_osmesa():
+  from dm_control._render.pyopengl.osmesa_renderer import OSMesaContext
+  return OSMesaContext
+
+
+# Import removed.
+# pylint: enable=g-import-not-at-top
+
+
+def _no_renderer():
+  def no_renderer(*args, **kwargs):
+    del args, kwargs
+    raise RuntimeError('No OpenGL rendering backend is available.')
+  return no_renderer
+
+
+_ALL_RENDERERS = (
+    (constants.GLFW, _import_glfw),
+    (constants.EGL, _import_egl),
+    (constants.OSMESA, _import_osmesa),
+    # Option removed.
+)
+
+_NO_RENDERER = (
+    (constants.NO_RENDERER, _no_renderer),
+)
+
+
+if BACKEND is not None:
+  # If a backend was specified, try importing it and error if unsuccessful.
+  import_func = None
+  for names, importer in _ALL_RENDERERS + _NO_RENDERER:
+    if BACKEND in names:
+      import_func = importer
+      BACKEND = names[0]  # canonicalize the renderer name
+      break
+  if import_func is None:
+    all_names = set()
+    for names, _ in _ALL_RENDERERS + _NO_RENDERER:
+      all_names.update(names)
+    raise RuntimeError(
+        'Environment variable {} must be one of {!r}: got {!r}.'
+        .format(constants.MUJOCO_GL, sorted(all_names), BACKEND))
+  logging.info('MUJOCO_GL=%s, attempting to import specified OpenGL backend.',
+               BACKEND)
+  Renderer = import_func()
+else:
+  logging.info('MUJOCO_GL is not set, so an OpenGL backend will be chosen '
+               'automatically.')
+  # Otherwise try importing them in descending order of priority until
+  # successful.
+  for names, import_func in _ALL_RENDERERS:
+    try:
+      Renderer = import_func()
+      BACKEND = names[0]
+      logging.info('Successfully imported OpenGL backend: %s', names[0])
+      break
+    except ImportError:
+      logging.info('Failed to import OpenGL backend: %s', names[0])
+  if BACKEND is None:
+    logging.info('No OpenGL backend could be imported. Attempting to create a '
+                 'rendering context will result in a RuntimeError.')
+    Renderer = _no_renderer()
+
+USING_GPU = BACKEND in constants.EGL + constants.GLFW
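The backend selection in `dm_control/_render/__init__.py` above can be approximated in isolation. The following is a simplified standalone sketch, not part of the patch: `select_backend` and its `importers` argument are hypothetical names introduced for illustration, and only the synonym tuples are copied from `constants.py`.

```python
"""Simplified sketch of the MUJOCO_GL backend selection (illustrative only)."""

# Synonym tuples as defined in `constants.py`; the first entry is canonical.
GLFW = ('glfw', 'on', 'enable', 'enabled', 'true', '1', '')
EGL = ('egl',)
OSMESA = ('osmesa',)
NO_RENDERER = ('off', 'disable', 'disabled', 'false', '0')


def select_backend(requested, importers):
  """Returns the canonical backend name, or None if nothing could be imported.

  Args:
    requested: Value of the MUJOCO_GL variable, or None if it is unset.
    importers: Hypothetical mapping from canonical name to a zero-argument
      callable that raises ImportError when that backend is unavailable.
  """
  priority = (GLFW, EGL, OSMESA)
  if requested is not None:
    for names in priority + (NO_RENDERER,):
      if requested in names:
        if names is not NO_RENDERER:
          importers[names[0]]()  # An ImportError propagates to the caller.
        return names[0]
    raise RuntimeError('Unknown backend: {!r}'.format(requested))
  # Nothing requested: fall through the backends in priority order.
  for names in priority:
    try:
      importers[names[0]]()
      return names[0]
    except ImportError:
      continue
  return None
```

Two consequences of this logic are easy to miss. An explicitly requested backend that fails to import raises instead of falling back, and a variable that is set but empty is treated as a request for GLFW, because `''` is listed among the GLFW synonyms.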
diff --git a/dm_control/_render/base.py b/dm_control/_render/base.py
new file mode 100644
index 00000000..66ea02f6
--- /dev/null
+++ b/dm_control/_render/base.py
@@ -0,0 +1,166 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for OpenGL context handlers.
+
+`ContextBase` defines a common API that OpenGL rendering contexts should conform
+to. In addition, it provides a `make_current` context manager that:
+
+1. Makes this OpenGL context current within the appropriate rendering thread.
+2. Yields an object exposing a `call` method that can be used to execute OpenGL
+ calls within the rendering thread.
+
+See the docstring for `dm_control._render.executor` for further details
+regarding rendering threads.
+"""
+
+import abc
+import atexit
+import collections
+import contextlib
+import sys
+import weakref
+
+from absl import logging
+from dm_control._render import executor
+import numpy as np
+
+
+_CURRENT_CONTEXT_FOR_THREAD = collections.defaultdict(lambda: None)
+_CURRENT_THREAD_FOR_CONTEXT = collections.defaultdict(lambda: None)
+
+
+class ContextBase(metaclass=abc.ABCMeta):
+  """Base class for managing OpenGL contexts."""
+
+  def __init__(self,
+               max_width,
+               max_height,
+               render_executor_class=executor.RenderExecutor):
+    """Initializes this context."""
+    logging.debug('Using render executor class: %s',
+                  render_executor_class.__name__)
+    self._render_executor = render_executor_class()
+    self._refcount = 0
+
+    self_weakref = weakref.ref(self)
+    def _free_at_exit():
+      if self_weakref():
+        self_weakref()._free_unconditionally()  # pylint: disable=protected-access
+    atexit.register(_free_at_exit)
+
+    with self._render_executor.execution_context() as ctx:
+      ctx.call(self._platform_init, max_width, max_height)
+
+    self._patients = []
+
+  def keep_alive(self, obj):
+    self._patients.append(obj)
+
+  def dont_keep_alive(self, obj):
+    try:
+      self._patients.remove(obj)
+    except ValueError:
+      pass
+
+  def increment_refcount(self):
+    self._refcount += 1
+
+  def decrement_refcount(self):
+    self._refcount -= 1
+
+  @property
+  def terminated(self):
+    return self._render_executor.terminated
+
+  @property
+  def thread(self):
+    return self._render_executor.thread
+
+  def _free_on_executor_thread(self):  # pylint: disable=missing-function-docstring
+    current_ctx = _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
+    if current_ctx is not None:
+      del _CURRENT_THREAD_FOR_CONTEXT[current_ctx]
+      del _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
+
+    self._platform_make_current()
+
+    try:
+      dummy = []
+      while self._patients:
+        patient = self._patients.pop()
+        assert sys.getrefcount(patient) == sys.getrefcount(dummy)
+        if hasattr(patient, 'free'):
+          patient.free()
+        del patient
+    finally:
+      self._platform_free()
+
+  def free(self):
+    """Frees resources associated with this context if its refcount is zero."""
+    if self._refcount == 0:
+      self._free_unconditionally()
+
+  def _free_unconditionally(self):
+    self._render_executor.terminate(self._free_on_executor_thread)
+
+  def __del__(self):
+    self._free_unconditionally()
+
+  @contextlib.contextmanager
+  def make_current(self):
+    """Context manager that makes this Renderer's OpenGL context current.
+
+    Yields:
+      An object that exposes a `call` method that can be used to call a
+      function on the dedicated rendering thread.
+
+    Raises:
+      RuntimeError: If this context is already current on another thread.
+    """
+
+    with self._render_executor.execution_context() as ctx:
+      if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] != id(self):
+        if _CURRENT_THREAD_FOR_CONTEXT[id(self)]:
+          raise RuntimeError(
+              'Cannot make context {!r} current on thread {!r}: '
+              'this context is already current on another thread {!r}.'
+              .format(self, self._render_executor.thread,
+                      _CURRENT_THREAD_FOR_CONTEXT[id(self)]))
+        else:
+          current_context = (
+              _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread])
+          if current_context:
+            del _CURRENT_THREAD_FOR_CONTEXT[current_context]
+          _CURRENT_THREAD_FOR_CONTEXT[id(self)] = self._render_executor.thread
+          _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] = id(self)
+          ctx.call(self._platform_make_current)
+      yield ctx
+
+  def to_pixels(self, buffer):
+    """Converts the buffer to pixels."""
+    return np.flipud(buffer)
+
+  @abc.abstractmethod
+  def _platform_init(self, max_width, max_height):
+    """Performs an implementation-specific context initialization."""
+
+  @abc.abstractmethod
+  def _platform_make_current(self):
+    """Make the OpenGL context current on the executing thread."""
+
+  @abc.abstractmethod
+  def _platform_free(self):
+    """Performs an implementation-specific context cleanup."""
diff --git a/dm_control/_render/base_test.py b/dm_control/_render/base_test.py
new file mode 100644
index 00000000..4caa4bf7
--- /dev/null
+++ b/dm_control/_render/base_test.py
@@ -0,0 +1,162 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the base rendering module."""
+
+import threading
+from absl.testing import absltest
+from dm_control._render import base
+from dm_control._render import executor
+
+WIDTH = 1024
+HEIGHT = 768
+
+
+class ContextBaseTests(absltest.TestCase):
+
+  class ContextMock(base.ContextBase):
+
+    def _platform_init(self, max_width, max_height):
+      self.init_thread = threading.current_thread()
+      self.make_current_count = 0
+      self.max_width = max_width
+      self.max_height = max_height
+      self.free_thread = None
+
+    def _platform_make_current(self):
+      self.make_current_count += 1
+      self.make_current_thread = threading.current_thread()
+
+    def _platform_free(self):
+      self.free_thread = threading.current_thread()
+
+  def setUp(self):
+    super().setUp()
+    self.context = ContextBaseTests.ContextMock(WIDTH, HEIGHT)
+
+  def test_init(self):
+    self.assertIs(self.context.init_thread, self.context.thread)
+    self.assertEqual(self.context.max_width, WIDTH)
+    self.assertEqual(self.context.max_height, HEIGHT)
+
+  def test_make_current(self):
+    self.assertEqual(self.context.make_current_count, 0)
+
+    with self.context.make_current():
+      pass
+    self.assertEqual(self.context.make_current_count, 1)
+    self.assertIs(self.context.make_current_thread, self.context.thread)
+
+    # Already current, shouldn't trigger a call to `_platform_make_current`.
+    with self.context.make_current():
+      pass
+    self.assertEqual(self.context.make_current_count, 1)
+    self.assertIs(self.context.make_current_thread, self.context.thread)
+
+  def test_thread_sharing(self):
+    first_context = ContextBaseTests.ContextMock(
+        WIDTH, HEIGHT, executor.PassthroughRenderExecutor)
+    second_context = ContextBaseTests.ContextMock(
+        WIDTH, HEIGHT, executor.PassthroughRenderExecutor)
+
+    with first_context.make_current():
+      pass
+    self.assertEqual(first_context.make_current_count, 1)
+
+    with first_context.make_current():
+      pass
+    self.assertEqual(first_context.make_current_count, 1)
+
+    with second_context.make_current():
+      pass
+    self.assertEqual(second_context.make_current_count, 1)
+
+    with second_context.make_current():
+      pass
+    self.assertEqual(second_context.make_current_count, 1)
+
+    with first_context.make_current():
+      pass
+    self.assertEqual(first_context.make_current_count, 2)
+
+    with second_context.make_current():
+      pass
+    self.assertEqual(second_context.make_current_count, 2)
+
+  def test_free(self):
+    with self.context.make_current():
+      pass
+
+    thread = self.context.thread
+    self.assertIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT)
+    self.assertIn(thread, base._CURRENT_CONTEXT_FOR_THREAD)
+
+    self.context.free()
+    self.assertIs(self.context.free_thread, thread)
+    self.assertIsNone(self.context.thread)
+
+    self.assertNotIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT)
+    self.assertNotIn(thread, base._CURRENT_CONTEXT_FOR_THREAD)
+
+  def test_free_with_multiple_contexts(self):
+    context1 = ContextBaseTests.ContextMock(WIDTH, HEIGHT,
+                                            executor.PassthroughRenderExecutor)
+    with context1.make_current():
+      pass
+
+    context2 = ContextBaseTests.ContextMock(WIDTH, HEIGHT,
+                                            executor.PassthroughRenderExecutor)
+    with context2.make_current():
+      pass
+
+    self.assertEqual(base._CURRENT_CONTEXT_FOR_THREAD[threading.main_thread()],
+                     id(context2))
+    self.assertIs(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)],
+                  threading.main_thread())
+
+    context1.free()
+    self.assertIsNone(
+        base._CURRENT_CONTEXT_FOR_THREAD[self.context.free_thread])
+    self.assertIsNone(base._CURRENT_THREAD_FOR_CONTEXT[id(context2)])
+
+  def test_refcounting(self):
+    thread = self.context.thread
+
+    self.assertEqual(self.context._refcount, 0)
+    self.context.increment_refcount()
+    self.assertEqual(self.context._refcount, 1)
+
+    # Context should not be freed yet, since its refcount is still positive.
+    self.context.free()
+    self.assertIsNone(self.context.free_thread)
+    self.assertIs(self.context.thread, thread)
+
+    # Decrement the refcount to zero.
+    self.context.decrement_refcount()
+    self.assertEqual(self.context._refcount, 0)
+
+    # Now the context can be freed.
+    self.context.free()
+    self.assertIs(self.context.free_thread, thread)
+    self.assertIsNone(self.context.thread)
+
+  def test_del(self):
+    self.assertIsNone(self.context.free_thread)
+    self.context.__del__()
+    self.assertIsNotNone(self.context.free_thread)
+
+
+if __name__ == '__main__':
+  absltest.main()
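`ContextBase.__init__` in `base.py` registers its `atexit` hook through a weak reference, presumably so that the hook itself does not keep the context alive. A minimal standalone sketch of that pattern follows; `FakeContext` and its `exit_hook` attribute are hypothetical and exist only so the hook can be exercised without exiting the interpreter.

```python
import atexit
import weakref


class FakeContext:
  """Hypothetical stand-in for a `ContextBase` subclass."""

  def __init__(self):
    self.freed = False
    self_weakref = weakref.ref(self)

    def _free_at_exit():
      # The closure captures only the weak reference, never `self`.
      obj = self_weakref()
      if obj is not None:
        obj.free()

    atexit.register(_free_at_exit)
    # Kept only so this example can call the hook directly.
    self.exit_hook = _free_at_exit

  def free(self):
    self.freed = True
```

Because `atexit` holds only the closure, and the closure holds only the weak reference, a `FakeContext` can still be garbage collected before interpreter exit. In that case the hook finds a dead reference and does nothing.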
diff --git a/dm_control/_render/constants.py b/dm_control/_render/constants.py
new file mode 100644
index 00000000..4d9cee65
--- /dev/null
+++ b/dm_control/_render/constants.py
@@ -0,0 +1,33 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""String constants for the rendering module."""
+
+# Name of the environment variable that selects a renderer platform.
+MUJOCO_GL = 'MUJOCO_GL'
+
+# Name of the environment variable that selects a platform for PyOpenGL.
+PYOPENGL_PLATFORM = 'PYOPENGL_PLATFORM'
+
+# Renderer platform specifiers.
+# All values in each tuple are synonyms for the MUJOCO_GL environment variable.
+# The first entry in each tuple is considered "canonical", and is the one
+# assigned to the _render.BACKEND variable.
+OSMESA = ('osmesa',)
+GLFW = ('glfw', 'on', 'enable', 'enabled', 'true', '1', '')
+EGL = ('egl',)
+# Constant removed.
+NO_RENDERER = ('off', 'disable', 'disabled', 'false', '0')
+
diff --git a/dm_control/_render/executor/__init__.py b/dm_control/_render/executor/__init__.py
new file mode 100644
index 00000000..12a9c48f
--- /dev/null
+++ b/dm_control/_render/executor/__init__.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RenderExecutor executes OpenGL rendering calls on an appropriate thread.
+
+OpenGL calls must be made on the same thread on which the OpenGL context was
+made current. With GPU rendering, migrating OpenGL contexts between threads
+can become expensive. We provide a thread-safe executor that maintains a
+thread on which an OpenGL context can be kept permanently current, and any other
+threads that wish to render with that context will have their rendering calls
+offloaded to the dedicated thread.
+
+For single-threaded applications, set the `DISABLE_RENDER_THREAD_OFFLOADING`
+environment variable before launching the Python interpreter. This will
+eliminate the overhead of unnecessary thread-switching.
+"""
+
+# pylint: disable=g-import-not-at-top
+import os
+_OFFLOAD = not bool(os.environ.get('DISABLE_RENDER_THREAD_OFFLOADING', ''))
+del os
+
+from dm_control._render.executor.render_executor import BaseRenderExecutor
+from dm_control._render.executor.render_executor import OffloadingRenderExecutor
+from dm_control._render.executor.render_executor import PassthroughRenderExecutor
+
+_EXECUTORS = (PassthroughRenderExecutor, OffloadingRenderExecutor)
+
+try:
+  from dm_control._render.executor.native_mutex.render_executor import NativeMutexOffloadingRenderExecutor
+  _EXECUTORS += (NativeMutexOffloadingRenderExecutor,)
+except ImportError:
+  NativeMutexOffloadingRenderExecutor = None
+
+if _OFFLOAD:
+  RenderExecutor = (  # pylint: disable=invalid-name
+      NativeMutexOffloadingRenderExecutor or OffloadingRenderExecutor)
+else:
+  RenderExecutor = PassthroughRenderExecutor  # pylint: disable=invalid-name
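The executor choice in `dm_control/_render/executor/__init__.py` reduces to a truthiness test on the environment variable. The sketch below restates it as a pure function; `choose_executor_name` and its arguments are hypothetical, and it returns class names as strings so that it runs without `dm_control` installed.

```python
def choose_executor_name(environ, native_available):
  """Returns the name of the executor class that the module would select.

  Args:
    environ: A mapping standing in for `os.environ`.
    native_available: Whether the optional native-mutex executor imported.
  """
  # Any non-empty value disables offloading, including '0' and 'false'.
  offload = not bool(environ.get('DISABLE_RENDER_THREAD_OFFLOADING', ''))
  if not offload:
    return 'PassthroughRenderExecutor'
  if native_available:
    return 'NativeMutexOffloadingRenderExecutor'
  return 'OffloadingRenderExecutor'
```

Since the check is `bool()` on a string, `DISABLE_RENDER_THREAD_OFFLOADING=0` still disables offloading. Only an unset or empty variable leaves it enabled.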
diff --git a/dm_control/_render/executor/render_executor.py b/dm_control/_render/executor/render_executor.py
new file mode 100644
index 00000000..4592e19a
--- /dev/null
+++ b/dm_control/_render/executor/render_executor.py
@@ -0,0 +1,218 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RenderExecutors execute OpenGL rendering calls on an appropriate thread.
+
+The purpose of these classes is to ensure that OpenGL calls are made on the
+same thread on which the OpenGL context was made current.
+
+In a single-threaded setting, `PassthroughRenderExecutor` is essentially a no-op
+that executes rendering calls on the same thread. This is provided to minimize
+thread-switching overhead.
+
+In a multithreaded setting, `OffloadingRenderExecutor` maintains a separate
+dedicated thread on which the OpenGL context is created and made current. All
+subsequent rendering calls are then offloaded onto this dedicated thread.
+"""
+
+import abc
+import collections
+from concurrent import futures
+import contextlib
+import threading
+
+
+_NOT_IN_CONTEXT = 'Cannot be called outside of an `execution_context`.'
+_ALREADY_TERMINATED = 'This executor has already been terminated.'
+
+
+class _FakeLock:
+  """An object with the same API as `threading.Lock` but that does nothing."""
+
+  def acquire(self, blocking=True):
+    pass
+
+  def release(self):
+    pass
+
+  def __enter__(self):
+    pass
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    del exc_type, exc_value, traceback
+
+
+_FAKE_LOCK = _FakeLock()
+
+
+class BaseRenderExecutor(metaclass=abc.ABCMeta):
+  """An object that manages rendering calls for an OpenGL context.
+
+  This class helps ensure that OpenGL calls are made on the correct thread. The
+  usage pattern is as follows:
+
+  ```python
+  executor = SomeRenderExecutorClass()
+  with executor.execution_context() as ctx:
+    ctx.call(an_opengl_call, arg, kwarg=foo)
+    result = ctx.call(another_opengl_call)
+  ```
+  """
+
+  def __init__(self):
+    self._locked = 0
+    self._terminated = False
+
+  def _check_locked(self):
+    if not self._locked:
+      raise RuntimeError(_NOT_IN_CONTEXT)
+
+  def _check_not_terminated(self):
+    if self._terminated:
+      raise RuntimeError(_ALREADY_TERMINATED)
+
+  @contextlib.contextmanager
+  def execution_context(self):
+    """A context manager that allows calls to be offloaded to this executor."""
+    self._check_not_terminated()
+    with self._lock_if_necessary:
+      self._locked += 1
+      yield self
+      self._locked -= 1
+
+  @property
+  def terminated(self):
+    return self._terminated
+
+  @property
+  @abc.abstractmethod
+  def thread(self):
+    pass
+
+  @property
+  @abc.abstractmethod
+  def _lock_if_necessary(self):
+    pass
+
+  @abc.abstractmethod
+  def call(self, *args, **kwargs):
+    pass
+
+  @abc.abstractmethod
+  def terminate(self, cleanup_callable=None):
+    pass
+
+
+class PassthroughRenderExecutor(BaseRenderExecutor):
+  """A no-op render executor that executes on the calling thread."""
+
+  def __init__(self):
+    super().__init__()
+    self._mutex = threading.RLock()
+
+  @property
+  def thread(self):
+    if not self._terminated:
+      return threading.current_thread()
+    else:
+      return None
+
+  @property
+  def _lock_if_necessary(self):
+    return self._mutex
+
+  def call(self, func, *args, **kwargs):
+    self._check_locked()
+    return func(*args, **kwargs)
+
+  def terminate(self, cleanup_callable=None):
+    with self._lock_if_necessary:
+      if not self._terminated:
+        if cleanup_callable:
+          cleanup_callable()
+        self._terminated = True
+
+
+class _ThreadPoolExecutorPool:
+  """A pool of reusable ThreadPoolExecutors."""
+
+  def __init__(self):
+    self._deque = collections.deque()
+    self._lock = threading.Lock()
+
+  def acquire(self):
+    with self._lock:
+      if self._deque:
+        return self._deque.popleft()
+      else:
+        return futures.ThreadPoolExecutor(max_workers=1)
+
+  def release(self, thread_pool_executor):
+    with self._lock:
+      self._deque.append(thread_pool_executor)
+
+
+_THREAD_POOL_EXECUTOR_POOL = _ThreadPoolExecutorPool()
+
+
+class OffloadingRenderExecutor(BaseRenderExecutor):
+ """A render executor that executes calls on a dedicated offload thread."""
+
+ def __init__(self):
+ super().__init__()
+ self._mutex = threading.RLock()
+ self._executor = _THREAD_POOL_EXECUTOR_POOL.acquire()
+ self._thread = self._executor.submit(threading.current_thread).result()
+
+ @property
+ def thread(self):
+ return self._thread
+
+ @property
+ def _lock_if_necessary(self):
+ if threading.current_thread() is self.thread:
+ # If the offload thread needs to make a call to its own executor, for
+ # example when a weakref callback is triggered during an offloaded call,
+ # then we must not try to reacquire our own lock.
+ # Otherwise, a deadlock ensues.
+ return _FAKE_LOCK
+ else:
+ return self._mutex
+
+ def call(self, func, *args, **kwargs):
+ self._check_locked()
+ return self._call_locked(func, *args, **kwargs)
+
+ def _call_locked(self, func, *args, **kwargs):
+ if threading.current_thread() is self.thread:
+ # If the offload thread needs to make a call to its own executor, for
+ # example when a weakref callback is triggered during an offloaded call,
+ # we should just directly call the function.
+ # Otherwise, a deadlock ensues.
+ return func(*args, **kwargs)
+ else:
+ return self._executor.submit(func, *args, **kwargs).result()
+
+ def terminate(self, cleanup_callable=None):
+ if self._terminated:
+ return
+ with self._lock_if_necessary:
+ if not self._terminated:
+ if cleanup_callable:
+ self._call_locked(cleanup_callable)
+ _THREAD_POOL_EXECUTOR_POOL.release(self._executor)
+ self._executor = None
+ self._thread = None
+ self._terminated = True
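Review note: the offloading pattern in `OffloadingRenderExecutor` can be tried in isolation. The following is a minimal sketch, not the class from this diff. `OffloadSketch` and its `shutdown` method are hypothetical names, and locking and termination handling are left out. It shows two ideas only: asking the single worker to report its own identity via `submit(threading.current_thread).result()`, and calling the function directly when the caller is already that worker.

```python
import threading
from concurrent import futures


class OffloadSketch:
  """Hypothetical stand-in for the offloading pattern (no locking)."""

  def __init__(self):
    self._pool = futures.ThreadPoolExecutor(max_workers=1)
    # The single worker reports its own identity, so later calls can detect
    # when they are already running on it.
    self._thread = self._pool.submit(threading.current_thread).result()

  @property
  def thread(self):
    return self._thread

  def call(self, func, *args, **kwargs):
    if threading.current_thread() is self._thread:
      # Already on the worker: submitting and then blocking on the result
      # would wait on a queue that only this thread can drain.
      return func(*args, **kwargs)
    return self._pool.submit(func, *args, **kwargs).result()

  def shutdown(self):
    self._pool.shutdown(wait=True)


sketch = OffloadSketch()
outer_thread = sketch.call(threading.current_thread)
# A nested call issued from the worker itself must not deadlock.
nested_thread = sketch.call(lambda: sketch.call(threading.current_thread))
sketch.shutdown()
```

Both `outer_thread` and `nested_thread` refer to the worker thread, which differs from the calling thread. Without the identity check in `call`, the nested call would block forever, because the only worker would be waiting on a future that it would have to run itself.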
diff --git a/dm_control/_render/executor/render_executor_test.py b/dm_control/_render/executor/render_executor_test.py
new file mode 100644
index 00000000..db4dc113
--- /dev/null
+++ b/dm_control/_render/executor/render_executor_test.py
@@ -0,0 +1,205 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control._render.executor.render_executor."""
+
+import threading
+import time
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control._render import executor
+import mock
+
+
+def enforce_timeout(timeout):
+ def wrap(test_func):
+ def wrapped_test(self, *args, **kwargs):
+ thread = threading.Thread(
+ target=test_func, args=((self,) + args), kwargs=kwargs)
+ thread.daemon = True
+ thread.start()
+ thread.join(timeout=timeout)
+ self.assertFalse(
+ thread.is_alive(),
+ msg='Test timed out after {} seconds.'.format(timeout))
+ return wrapped_test
+ return wrap
+
+
+class RenderExecutorTest(parameterized.TestCase):
+
+ def _make_executor(self, executor_type):
+ if (executor_type == executor.NativeMutexOffloadingRenderExecutor and
+ executor_type is None):
+ raise unittest.SkipTest(
+ 'NativeMutexOffloadingRenderExecutor is not available.')
+ else:
+ return executor_type()
+
+ def test_passthrough_executor_thread(self):
+ render_executor = self._make_executor(executor.PassthroughRenderExecutor)
+ self.assertIs(render_executor.thread, threading.current_thread())
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_offloading_executor_thread(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ self.assertIsNot(render_executor.thread, threading.current_thread())
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_on_correct_thread(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ with render_executor.execution_context() as ctx:
+ actual_executed_thread = ctx.call(threading.current_thread)
+ self.assertIs(actual_executed_thread, render_executor.thread)
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_multithreaded(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ list_length = 5
+ shared_list = [None] * list_length
+
+ def fill_list(thread_idx):
+ def assign_value(i):
+ shared_list[i] = thread_idx
+ for _ in range(1000):
+ with render_executor.execution_context() as ctx:
+ for i in range(list_length):
+ ctx.call(assign_value, i)
+ # Other threads should be prevented from calling `assign_value` while
+ # this thread is inside the `execution_context`.
+ self.assertEqual(shared_list, [thread_idx] * list_length)
+
+ threads = [threading.Thread(target=fill_list, args=(i,)) for i in range(9)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_exception(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ message = 'fake error'
+ def raise_value_error():
+ raise ValueError(message)
+ with render_executor.execution_context() as ctx:
+ with self.assertRaisesWithLiteralMatch(ValueError, message):
+ ctx.call(raise_value_error)
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_terminate(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ cleanup = mock.MagicMock()
+ render_executor.terminate(cleanup)
+ cleanup.assert_called_once_with()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_outside_of_context(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ func = mock.MagicMock()
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._NOT_IN_CONTEXT):
+ render_executor.call(func)
+ # Also test that the locked flag is properly cleared when leaving a context.
+ with render_executor.execution_context():
+ render_executor.call(lambda: None)
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._NOT_IN_CONTEXT):
+ render_executor.call(func)
+ func.assert_not_called()
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_after_terminate(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ render_executor.terminate()
+ func = mock.MagicMock()
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._ALREADY_TERMINATED):
+ with render_executor.execution_context() as ctx:
+ ctx.call(func)
+ func.assert_not_called()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_locking(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ other_thread_context_entered = threading.Condition()
+ other_thread_context_done = [False]
+ def other_thread_func():
+ with render_executor.execution_context():
+ with other_thread_context_entered:
+ other_thread_context_entered.notify()
+ time.sleep(1)
+ other_thread_context_done[0] = True
+ other_thread = threading.Thread(target=other_thread_func)
+ with other_thread_context_entered:
+ other_thread.start()
+ other_thread_context_entered.wait()
+ with render_executor.execution_context():
+ self.assertTrue(
+ other_thread_context_done[0],
+ msg=('Main thread should not be able to enter the execution context '
+ 'until the other thread is done.'))
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ @enforce_timeout(timeout=5.)
+ def test_reentrant_locking(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ def triple_lock(render_executor):
+ with render_executor.execution_context():
+ with render_executor.execution_context():
+ with render_executor.execution_context():
+ pass
+ triple_lock(render_executor)
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ @enforce_timeout(timeout=5.)
+ def test_no_deadlock_in_callbacks(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ # This test times out in the event of a deadlock.
+ def callback():
+ with render_executor.execution_context() as ctx:
+ ctx.call(lambda: None)
+ with render_executor.execution_context() as ctx:
+ ctx.call(callback)
+
+if __name__ == '__main__':
+ absltest.main()
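Review note: `enforce_timeout` runs the test body on a daemon thread and only asserts that the thread has finished. In Python, an exception raised inside a `threading.Thread` target is not re-raised in the thread that calls `join`, so a failing assertion inside the wrapped body would not by itself fail the test. The sketch below is one possible way to surface such errors. `run_with_timeout` is a hypothetical helper, not part of this diff.

```python
import threading
import time


def run_with_timeout(func, timeout):
  """Runs `func` on a daemon thread; returns (finished, errors)."""
  errors = []

  def target():
    try:
      func()
    except BaseException as exc:  # Collected so the caller can inspect it.
      errors.append(exc)

  worker = threading.Thread(target=target)
  worker.daemon = True  # A hung worker must not block interpreter exit.
  worker.start()
  worker.join(timeout=timeout)
  return (not worker.is_alive()), errors


def failing_body():
  raise ValueError('fake error')


ok_finished, ok_errors = run_with_timeout(lambda: None, timeout=1.0)
slow_finished, _ = run_with_timeout(lambda: time.sleep(2.0), timeout=0.1)
bad_finished, bad_errors = run_with_timeout(failing_body, timeout=1.0)
```

A decorator built on this helper could re-raise the first collected error after checking `finished`, so that timeouts and ordinary failures are both reported.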
diff --git a/dm_control/render/glfw_renderer.py b/dm_control/_render/glfw_renderer.py
similarity index 57%
rename from dm_control/render/glfw_renderer.py
rename to dm_control/_render/glfw_renderer.py
index 3ecf684b..301e40a0 100644
--- a/dm_control/render/glfw_renderer.py
+++ b/dm_control/_render/glfw_renderer.py
@@ -15,72 +15,54 @@
"""An OpenGL renderer backed by GLFW."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-# Internal dependencies.
-
-from dm_control.render import base
-import six
+from dm_control._render import base
+from dm_control._render import executor
# Re-raise any exceptions that occur during module import as `ImportError`s.
# This simplifies the conditional imports in `render/__init__.py`.
try:
import glfw # pylint: disable=g-import-not-at-top
except (ImportError, IOError, OSError) as exc:
- _, exc, tb = sys.exc_info()
- six.reraise(ImportError, ImportError(str(exc)), tb)
+ raise ImportError from exc
try:
glfw.init()
except glfw.GLFWError as exc:
- _, exc, tb = sys.exc_info()
- six.reraise(ImportError, ImportError(str(exc)), tb)
+ raise ImportError from exc
class GLFWContext(base.ContextBase):
"""An OpenGL context backed by GLFW."""
def __init__(self, max_width, max_height):
+ # GLFWContext always uses `PassthroughRenderExecutor` rather than offloading
+ # rendering calls to a separate thread because GLFW can only be safely used
+ # from the main thread.
+ super().__init__(max_width, max_height, executor.PassthroughRenderExecutor)
+
+ def _platform_init(self, max_width, max_height):
"""Initializes this context.
Args:
max_width: Integer specifying the maximum framebuffer width in pixels.
max_height: Integer specifying the maximum framebuffer height in pixels.
"""
- super(GLFWContext, self).__init__()
glfw.window_hint(glfw.VISIBLE, 0)
glfw.window_hint(glfw.DOUBLEBUFFER, 0)
self._context = glfw.create_window(width=max_width, height=max_height,
title='Invisible window', monitor=None,
share=None)
- self._previous_context = None
- # This reference prevents `glfw` from being garbage-collected before the
- # last window is destroyed, otherwise we may get `AttributeError`s when the
- # `__del__` method is later called.
- self._glfw = glfw
+ # This reference prevents `glfw.destroy_window` from being garbage-collected
+ # before the last window is destroyed, otherwise we may get
+ # `AttributeError`s when the `__del__` method is later called.
+ self._destroy_window = glfw.destroy_window
- def activate(self, width, height):
- """Called when entering the `make_current` context manager.
-
- Args:
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
- """
- self._previous_context = glfw.get_current_context()
+ def _platform_make_current(self):
glfw.make_context_current(self._context)
- if (width, height) != glfw.get_window_size(self._context):
- glfw.set_window_size(self._context, width, height)
-
- def deactivate(self):
- """Called when exiting the `make_current` context manager."""
- glfw.make_context_current(self._previous_context)
- def _free(self):
+ def _platform_free(self):
"""Frees resources associated with this context."""
- self._previous_context = None
if self._context:
- glfw.destroy_window(self._context)
+ if glfw.get_current_context() == self._context:
+ glfw.make_context_current(None)
+ self._destroy_window(self._context)
self._context = None
diff --git a/dm_control/_render/glfw_renderer_test.py b/dm_control/_render/glfw_renderer_test.py
new file mode 100644
index 00000000..b599690b
--- /dev/null
+++ b/dm_control/_render/glfw_renderer_test.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for GLFWContext."""
+
+import unittest
+from absl.testing import absltest
+from dm_control import _render
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.testing import decorators
+
+import mock # pylint: disable=g-import-not-at-top
+
+MAX_WIDTH = 1024
+MAX_HEIGHT = 1024
+
+CONTEXT_PATH = _render.__name__ + '.glfw_renderer.glfw'
+
+
+@unittest.skipUnless(
+ _render.BACKEND == _render.constants.GLFW[0],
+ reason='GLFW backend not selected.')
+class GLFWContextTest(absltest.TestCase):
+
+ def test_init(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ self.assertIs(renderer._context, mock_context)
+
+ def test_make_current(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ with renderer.make_current():
+ pass
+ mock_glfw.make_context_current.assert_called_once_with(mock_context)
+
+ def test_freeing(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+ mock_glfw.destroy_window.assert_called_once_with(mock_context)
+ self.assertIsNone(renderer._context)
+
+ @decorators.run_threaded(num_threads=1, calls_per_thread=20)
+ def test_repeatedly_create_and_destroy_context(self):
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ wrapper.MjrContext(wrapper.MjModel.from_xml_string('<mujoco/>'), renderer)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/_render/pyopengl/__init__.py b/dm_control/_render/pyopengl/__init__.py
new file mode 100644
index 00000000..a514c4bb
--- /dev/null
+++ b/dm_control/_render/pyopengl/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/_render/pyopengl/egl_ext.py b/dm_control/_render/pyopengl/egl_ext.py
new file mode 100644
index 00000000..5cfc615c
--- /dev/null
+++ b/dm_control/_render/pyopengl/egl_ext.py
@@ -0,0 +1,74 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Extends OpenGL.EGL with definitions necessary for headless rendering."""
+
+import ctypes
+from OpenGL.platform import ctypesloader # pylint: disable=g-bad-import-order
+try:
+ # Nvidia driver seems to need libOpenGL.so (as opposed to libGL.so)
+ # for multithreading to work properly. We load this in before everything else.
+ ctypesloader.loadLibrary(ctypes.cdll, 'OpenGL', mode=ctypes.RTLD_GLOBAL)
+except OSError:
+ pass
+
+# pylint: disable=g-import-not-at-top
+
+from OpenGL import EGL
+from OpenGL import error
+
+
+# From the EGL_EXT_device_enumeration extension.
+PFNEGLQUERYDEVICESEXTPROC = ctypes.CFUNCTYPE(
+ EGL.EGLBoolean,
+ EGL.EGLint,
+ ctypes.POINTER(EGL.EGLDeviceEXT),
+ ctypes.POINTER(EGL.EGLint),
+)
+try:
+ _eglQueryDevicesEXT = PFNEGLQUERYDEVICESEXTPROC( # pylint: disable=invalid-name
+ EGL.eglGetProcAddress('eglQueryDevicesEXT'))
+except TypeError:
+ raise ImportError('eglQueryDevicesEXT is not available.')
+
+
+# From the EGL_EXT_platform_device extension.
+EGL_PLATFORM_DEVICE_EXT = 0x313F
+PFNEGLGETPLATFORMDISPLAYEXTPROC = ctypes.CFUNCTYPE(
+ EGL.EGLDisplay, EGL.EGLenum, ctypes.c_void_p, ctypes.POINTER(EGL.EGLint))
+try:
+ eglGetPlatformDisplayEXT = PFNEGLGETPLATFORMDISPLAYEXTPROC( # pylint: disable=invalid-name
+ EGL.eglGetProcAddress('eglGetPlatformDisplayEXT'))
+except TypeError:
+ raise ImportError('eglGetPlatformDisplayEXT is not available.')
+
+
+# Wrap raw _eglQueryDevicesEXT function into something more Pythonic.
+def eglQueryDevicesEXT(max_devices=10): # pylint: disable=invalid-name
+ devices = (EGL.EGLDeviceEXT * max_devices)()
+ num_devices = EGL.EGLint()
+ success = _eglQueryDevicesEXT(max_devices, devices, num_devices)
+ if success == EGL.EGL_TRUE:
+ return [devices[i] for i in range(num_devices.value)]
+ else:
+ raise error.GLError(err=EGL.eglGetError(),
+ baseOperation=eglQueryDevicesEXT,
+ result=success)
+
+
+# Expose everything from upstream so that
+# we can use this as a drop-in replacement for OpenGL.EGL.
+# pylint: disable=wildcard-import,g-bad-import-order
+from OpenGL.EGL import *
diff --git a/dm_control/_render/pyopengl/egl_renderer.py b/dm_control/_render/pyopengl/egl_renderer.py
new file mode 100644
index 00000000..b9dfa149
--- /dev/null
+++ b/dm_control/_render/pyopengl/egl_renderer.py
@@ -0,0 +1,140 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by EGL, provided through PyOpenGL."""
+
+import atexit
+import ctypes
+import os
+
+from dm_control._render import base
+from dm_control._render import constants
+from dm_control._render import executor
+
+PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM)
+
+if not PYOPENGL_PLATFORM:
+ os.environ[constants.PYOPENGL_PLATFORM] = constants.EGL[0]
+elif PYOPENGL_PLATFORM != constants.EGL[0]:
+ raise ImportError(
+ 'Cannot use EGL rendering platform. '
+ 'The PYOPENGL_PLATFORM environment variable is set to {!r} '
+ '(should be either unset or {!r}).'
+ .format(PYOPENGL_PLATFORM, constants.EGL[0]))
+
+
+# pylint: disable=g-import-not-at-top
+from dm_control._render.pyopengl import egl_ext as EGL
+from OpenGL import error
+
+
+def create_initialized_headless_egl_display():
+ """Creates an initialized EGL display directly on a device."""
+ all_devices = EGL.eglQueryDevicesEXT()
+ selected_device = os.environ.get('MUJOCO_EGL_DEVICE_ID', None)
+ if selected_device is None:
+ candidates = all_devices
+ else:
+ device_idx = int(selected_device)
+ if not 0 <= device_idx < len(all_devices):
+ raise RuntimeError(
+ f'MUJOCO_EGL_DEVICE_ID must be an integer between 0 and '
+ f'{len(all_devices) - 1} (inclusive), got {device_idx}.')
+ candidates = all_devices[device_idx:device_idx + 1]
+ for device in candidates:
+ display = EGL.eglGetPlatformDisplayEXT(
+ EGL.EGL_PLATFORM_DEVICE_EXT, device, None)
+ if display != EGL.EGL_NO_DISPLAY and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ # `eglInitialize` may or may not raise an exception on failure depending
+ # on how PyOpenGL is configured. We therefore catch a `GLError` and also
+ # manually check the output of `eglGetError()` here.
+ try:
+ initialized = EGL.eglInitialize(display, None, None)
+ except error.GLError:
+ pass
+ else:
+ if initialized == EGL.EGL_TRUE and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ return display
+ return EGL.EGL_NO_DISPLAY
+
+
+EGL_DISPLAY = create_initialized_headless_egl_display()
+if EGL_DISPLAY == EGL.EGL_NO_DISPLAY:
+ raise ImportError('Cannot initialize a headless EGL display.')
+atexit.register(EGL.eglTerminate, EGL_DISPLAY)
+
+
+EGL_ATTRIBUTES = (
+ EGL.EGL_RED_SIZE, 8,
+ EGL.EGL_GREEN_SIZE, 8,
+ EGL.EGL_BLUE_SIZE, 8,
+ EGL.EGL_ALPHA_SIZE, 8,
+ EGL.EGL_DEPTH_SIZE, 24,
+ EGL.EGL_STENCIL_SIZE, 8,
+ EGL.EGL_COLOR_BUFFER_TYPE, EGL.EGL_RGB_BUFFER,
+ EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
+ EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_BIT,
+ EGL.EGL_NONE
+)
+
+
+class EGLContext(base.ContextBase):
+ """An OpenGL context backed by EGL."""
+
+ def __init__(self, max_width, max_height):
+ # EGLContext currently only works with `PassthroughRenderExecutor`.
+ # TODO(b/110927854) Make this work with the offloading executor.
+ self._context = None
+ super().__init__(max_width, max_height, executor.PassthroughRenderExecutor)
+
+ def _platform_init(self, unused_max_width, unused_max_height):
+ """Initializes this EGL context."""
+ num_configs = ctypes.c_long(0)
+ config_size = 1
+ # ctypes syntax for making an array of length config_size.
+ configs = (EGL.EGLConfig * config_size)()
+ EGL.eglReleaseThread()
+ EGL.eglChooseConfig(
+ EGL_DISPLAY,
+ EGL_ATTRIBUTES,
+ configs,
+ config_size,
+ num_configs)
+ if num_configs.value < 1:
+ raise RuntimeError(
+ 'EGL failed to find a framebuffer configuration that matches the '
+ 'desired attributes: {}'.format(EGL_ATTRIBUTES))
+ EGL.eglBindAPI(EGL.EGL_OPENGL_API)
+ self._context = EGL.eglCreateContext(
+ EGL_DISPLAY, configs[0], EGL.EGL_NO_CONTEXT, None)
+ if not self._context:
+ raise RuntimeError('Cannot create an EGL context.')
+
+ def _platform_make_current(self):
+ if self._context:
+ success = EGL.eglMakeCurrent(
+ EGL_DISPLAY, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, self._context)
+ if not success:
+ raise RuntimeError('Failed to make the EGL context current.')
+
+ def _platform_free(self):
+ """Frees resources associated with this context."""
+ if self._context:
+ current_context = EGL.eglGetCurrentContext()
+ if current_context and self._context.address == current_context.address:
+ EGL.eglMakeCurrent(EGL_DISPLAY, EGL.EGL_NO_SURFACE,
+ EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
+ EGL.eglDestroyContext(EGL_DISPLAY, self._context)
+ self._context = None
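Review note: the device selection in `create_initialized_headless_egl_display` can be exercised without an EGL driver if the candidate filtering is separated from the EGL calls. The sketch below assumes a hypothetical helper, `select_candidates`, that takes the device list and the raw value of `MUJOCO_EGL_DEVICE_ID`. It mirrors the bounds check in this diff but is not code from the module.

```python
def select_candidates(all_devices, device_id):
  """Returns the devices to try, given a raw MUJOCO_EGL_DEVICE_ID value."""
  if device_id is None:
    # No preference: every enumerated device is a candidate.
    return list(all_devices)
  device_idx = int(device_id)  # A non-integer string raises ValueError here.
  if not 0 <= device_idx < len(all_devices):
    raise RuntimeError(
        'MUJOCO_EGL_DEVICE_ID must be an integer between 0 and '
        f'{len(all_devices) - 1} (inclusive), got {device_idx}.')
  return list(all_devices[device_idx:device_idx + 1])


devices = ['gpu0', 'gpu1', 'gpu2']  # Placeholder values, not EGL handles.
unrestricted = select_candidates(devices, None)
pinned = select_candidates(devices, '1')
try:
  select_candidates(devices, '3')
  out_of_range_raised = False
except RuntimeError:
  out_of_range_raised = True
```

In the module itself the second argument would come from `os.environ.get('MUJOCO_EGL_DEVICE_ID', None)`. Note that a non-numeric value surfaces as a `ValueError` from `int()` rather than as the `RuntimeError` with the explanatory message.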
diff --git a/dm_control/_render/pyopengl/osmesa_renderer.py b/dm_control/_render/pyopengl/osmesa_renderer.py
new file mode 100644
index 00000000..75f4bd20
--- /dev/null
+++ b/dm_control/_render/pyopengl/osmesa_renderer.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by OSMesa."""
+
+import os
+
+from dm_control._render import base
+from dm_control._render import constants
+
+PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM)
+
+if not PYOPENGL_PLATFORM:
+ os.environ[constants.PYOPENGL_PLATFORM] = constants.OSMESA[0]
+elif PYOPENGL_PLATFORM != constants.OSMESA[0]:
+ raise ImportError(
+ 'Cannot use OSMesa rendering platform. '
+ 'The PYOPENGL_PLATFORM environment variable is set to {!r} '
+ '(should be either unset or {!r}).'
+ .format(PYOPENGL_PLATFORM, constants.OSMESA[0]))
+
+# pylint: disable=g-import-not-at-top
+from OpenGL import GL
+from OpenGL import osmesa
+from OpenGL.GL import arrays
+
+_DEPTH_BITS = 24
+_STENCIL_BITS = 8
+_ACCUM_BITS = 0
+
+
+class OSMesaContext(base.ContextBase):
+ """An OpenGL context backed by OSMesa."""
+
+ def __init__(self, *args, **kwargs):
+ self._context = None
+ super().__init__(*args, **kwargs)
+
+ def _platform_init(self, max_width, max_height):
+ """Initializes this OSMesa context."""
+ self._context = osmesa.OSMesaCreateContextExt(
+ osmesa.OSMESA_RGBA,
+ _DEPTH_BITS,
+ _STENCIL_BITS,
+ _ACCUM_BITS,
+ None, # sharelist
+ )
+ if not self._context:
+ raise RuntimeError('Failed to create OSMesa GL context.')
+
+ self._height = max_height
+ self._width = max_width
+
+ # Allocate a buffer to render into.
+ self._buffer = arrays.GLfloatArray.zeros((max_height, max_width, 4))
+
+ def _platform_make_current(self):
+ if self._context:
+ success = osmesa.OSMesaMakeCurrent(
+ self._context,
+ self._buffer,
+ GL.GL_FLOAT,
+ self._width,
+ self._height)
+ if not success:
+ raise RuntimeError('Failed to make OSMesa context current.')
+
+ def _platform_free(self):
+ """Frees resources associated with this context."""
+ if self._context and self._context == osmesa.OSMesaGetCurrentContext():
+ osmesa.OSMesaMakeCurrent(None, None, GL.GL_FLOAT, 0, 0)
+ osmesa.OSMesaDestroyContext(self._context)
+ self._buffer = None
+ self._context = None
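Review note: `egl_renderer.py` and `osmesa_renderer.py` repeat the same `PYOPENGL_PLATFORM` check at import time. The sketch below shows that check as a single function operating on a plain mapping, which makes it testable without touching `os.environ`. `ensure_pyopengl_platform` is a hypothetical name and is not defined anywhere in this diff.

```python
def ensure_pyopengl_platform(environ, wanted):
  """Sets PYOPENGL_PLATFORM to `wanted`, or raises if it conflicts."""
  current = environ.get('PYOPENGL_PLATFORM')
  if not current:
    # Unset or empty: claim the platform before PyOpenGL is imported.
    environ['PYOPENGL_PLATFORM'] = wanted
  elif current != wanted:
    raise ImportError(
        'Cannot use {!r} rendering platform. '
        'The PYOPENGL_PLATFORM environment variable is set to {!r} '
        '(should be either unset or {!r}).'.format(wanted, current, wanted))


fresh_env = {}
ensure_pyopengl_platform(fresh_env, 'egl')

matching_env = {'PYOPENGL_PLATFORM': 'osmesa'}
ensure_pyopengl_platform(matching_env, 'osmesa')

try:
  ensure_pyopengl_platform({'PYOPENGL_PLATFORM': 'osmesa'}, 'egl')
  conflict_raised = False
except ImportError:
  conflict_raised = True
```

Raising `ImportError` rather than another exception type matches the two renderer modules above, where a platform mismatch is reported as a failed import.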
diff --git a/dm_control/_render/pyopengl/osmesa_renderer_test.py b/dm_control/_render/pyopengl/osmesa_renderer_test.py
new file mode 100644
index 00000000..7818d4b4
--- /dev/null
+++ b/dm_control/_render/pyopengl/osmesa_renderer_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for OSMesaContext."""
+
+import unittest
+
+from absl.testing import absltest
+from dm_control import _render
+import mock
+from OpenGL import GL
+
+MAX_WIDTH = 640
+MAX_HEIGHT = 480
+
+CONTEXT_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.osmesa'
+GL_ARRAYS_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.arrays'
+
+
+@unittest.skipUnless(
+ _render.BACKEND == _render.constants.OSMESA[0],
+ reason='OSMesa backend not selected.')
+class OSMesaContextTest(absltest.TestCase):
+
+ def test_init(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ self.assertIs(renderer._context, mock_context)
+ renderer.free()
+
+ def test_make_current(self):
+ mock_context = mock.MagicMock()
+ mock_buffer = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ with mock.patch(GL_ARRAYS_PATH) as mock_glarrays:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ mock_glarrays.GLfloatArray.zeros.return_value = mock_buffer
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ with renderer.make_current():
+ pass
+ mock_osmesa.OSMesaMakeCurrent.assert_called_once_with(
+ mock_context, mock_buffer, GL.GL_FLOAT, MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+
+ def test_freeing(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+ mock_osmesa.OSMesaDestroyContext.assert_called_once_with(mock_context)
+ self.assertIsNone(renderer._context)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/autowrap/autowrap.py b/dm_control/autowrap/autowrap.py
index 7da4378e..3aa0a96b 100644
--- a/dm_control/autowrap/autowrap.py
+++ b/dm_control/autowrap/autowrap.py
@@ -15,16 +15,21 @@
r"""Automatically generates ctypes Python bindings for MuJoCo.
-Parses mjdata.h, mjmodel.h, mjrender.h, mjvisualize.h, mjxmacro.h and mujoco.h;
+Parses the following MuJoCo header files:
+
+ mjdata.h
+ mjmodel.h
+ mjrender.h
+ mjui.h
+ mjvisualize.h
+ mjxmacro.h
+ mujoco.h;
+
generates the following Python source files:
constants.py: constants
enums.py: enums
sizes.py: size information for dynamically-shaped arrays
- types.py: ctypes declarations for structs
- wrappers.py: low-level Python wrapper classes for structs (these implement
- getter/setter methods for struct members where applicable)
- functions.py: ctypes function declarations for MuJoCo API functions
Example usage:
@@ -32,23 +37,18 @@
--output_dir=/path/to/mjbindings
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
+import io
import os
-# Internal dependencies.
-
from absl import app
from absl import flags
from absl import logging
-
from dm_control.autowrap import binding_generator
from dm_control.autowrap import codegen_util
-import six
+_MJMODEL_H = "mjmodel.h"
+_MJXMACRO_H = "mjxmacro.h"
FLAGS = flags.FLAGS
@@ -61,18 +61,25 @@
def main(unused_argv):
- # Get the path to the xmacro header file.
- xmacro_hdr_path = None
- for path in FLAGS.header_paths:
- if path.endswith("mjxmacro.h"):
- xmacro_hdr_path = path
- break
- if xmacro_hdr_path is None:
- logging.fatal("List of inputs must contain a path to mjxmacro.h")
-
+ special_header_paths = {}
+
+ # Get the path to the mjmodel and mjxmacro header files.
+ # These header files need special handling.
+ for header in (_MJMODEL_H, _MJXMACRO_H):
+ for path in FLAGS.header_paths:
+ if path.endswith(header):
+ special_header_paths[header] = path
+ break
+ if header not in special_header_paths:
+ logging.fatal("List of inputs must contain a path to %s", header)
+
+ # Make sure mjmodel.h is parsed first, since it is included by other headers.
srcs = codegen_util.UniqueOrderedDict()
- for p in sorted(FLAGS.header_paths):
- with open(p, "r") as f:
+ sorted_header_paths = sorted(FLAGS.header_paths)
+ sorted_header_paths.remove(special_header_paths[_MJMODEL_H])
+ sorted_header_paths.insert(0, special_header_paths[_MJMODEL_H])
+ for p in sorted_header_paths:
+ with io.open(p, "r", errors="ignore") as f:
srcs[p] = f.read()
# consts_dict should be a codegen_util.UniqueOrderedDict.
@@ -86,39 +93,23 @@ def main(unused_argv):
# These are commented in `mjdata.h` but have no macros in `mjxmacro.h`.
hints_dict = codegen_util.UniqueOrderedDict({"buffer": ("nbuffer",),
- "stack": ("nstack",)})
+ "stack": ("narena",)})
parser = binding_generator.BindingGenerator(
consts_dict=consts_dict, hints_dict=hints_dict)
# Parse enums.
- for pth, src in six.iteritems(srcs):
- if pth is not xmacro_hdr_path:
+ for pth, src in srcs.items():
+ if pth != special_header_paths[_MJXMACRO_H]:
parser.parse_enums(src)
# Parse constants and type declarations.
- for pth, src in six.iteritems(srcs):
- if pth is not xmacro_hdr_path:
+ for pth, src in srcs.items():
+ if pth != special_header_paths[_MJXMACRO_H]:
parser.parse_consts_typedefs(src)
# Get shape hints from mjxmacro.h.
- parser.parse_hints(srcs[xmacro_hdr_path])
-
- # Parse structs.
- for pth, src in six.iteritems(srcs):
- if pth is not xmacro_hdr_path:
- parser.parse_structs(src)
-
- # Parse functions.
- for pth, src in six.iteritems(srcs):
- if pth is not xmacro_hdr_path:
- parser.parse_functions(src)
-
- # Parse global strings and function pointers.
- for pth, src in six.iteritems(srcs):
- if pth is not xmacro_hdr_path:
- parser.parse_global_strings(src)
- parser.parse_function_pointers(src)
+ parser.parse_hints(srcs[special_header_paths[_MJXMACRO_H]])
# Create the output directory if it doesn't already exist.
if not os.path.exists(FLAGS.output_dir):
@@ -127,9 +118,6 @@ def main(unused_argv):
# Generate Python source files and write them to the output directory.
parser.write_consts(os.path.join(FLAGS.output_dir, "constants.py"))
parser.write_enums(os.path.join(FLAGS.output_dir, "enums.py"))
- parser.write_types(os.path.join(FLAGS.output_dir, "types.py"))
- parser.write_wrappers(os.path.join(FLAGS.output_dir, "wrappers.py"))
- parser.write_funcs_and_globals(os.path.join(FLAGS.output_dir, "functions.py"))
parser.write_index_dict(os.path.join(FLAGS.output_dir, "sizes.py"))
if __name__ == "__main__":
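
Review note: the header-ordering step in `main` above can be read as "sort the paths, then move `mjmodel.h` to the front". A minimal standalone sketch of that idea follows. The helper name `order_headers` and its parameters are hypothetical and not part of `autowrap.py`, and it raises `ValueError` where the script calls `logging.fatal`.

```python
def order_headers(header_paths, first_header="mjmodel.h"):
  """Returns the paths sorted, with the `first_header` path moved to the front.

  Hypothetical helper for illustration only; autowrap.py does this inline.
  """
  # Find the first input path that ends with the special header's name.
  first_path = next(
      (p for p in header_paths if p.endswith(first_header)), None)
  if first_path is None:
    raise ValueError(
        "List of inputs must contain a path to {}".format(first_header))
  # Sort for a deterministic order, then move the special header to index 0.
  ordered = sorted(header_paths)
  ordered.remove(first_path)
  ordered.insert(0, first_path)
  return ordered


if __name__ == "__main__":
  print(order_headers(
      ["/inc/mujoco.h", "/inc/mjxmacro.h", "/inc/mjmodel.h", "/inc/mjdata.h"]))
```

Keeping the remaining headers in sorted order means the generated output does not depend on the order in which paths were passed on the command line.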
diff --git a/dm_control/autowrap/binding_generator.py b/dm_control/autowrap/binding_generator.py
index 023f71c8..370a543e 100644
--- a/dm_control/autowrap/binding_generator.py
+++ b/dm_control/autowrap/binding_generator.py
@@ -15,46 +15,25 @@
"""Parses MuJoCo header files and generates Python bindings."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import os
import pprint
import textwrap
-# Internal dependencies.
-
from absl import logging
-
-from dm_control.autowrap import c_declarations
from dm_control.autowrap import codegen_util
from dm_control.autowrap import header_parsing
-
import pyparsing
-import six
-
-# Absolute path to the top-level module.
-_MODULE = "dm_control.mujoco.wrapper"
-# Imports used in all generated source files.
-_BOILERPLATE_IMPORTS = [
- "from __future__ import absolute_import",
- "from __future__ import division",
- "from __future__ import print_function\n",
-]
-
-class Error(Exception):
- pass
-
-
-class BindingGenerator(object):
+class BindingGenerator:
"""Parses declarations from MuJoCo headers and generates Python bindings."""
- def __init__(self, enums_dict=None, consts_dict=None, typedefs_dict=None,
- hints_dict=None, structs_dict=None, funcs_dict=None,
- strings_dict=None, func_ptrs_dict=None, index_dict=None):
+ def __init__(self,
+ enums_dict=None,
+ consts_dict=None,
+ typedefs_dict=None,
+ hints_dict=None,
+ index_dict=None):
"""Constructs a new HeaderParser instance.
The optional arguments listed below can be used to passing in dict-like
@@ -63,15 +42,11 @@ def __init__(self, enums_dict=None, consts_dict=None, typedefs_dict=None,
contents of the headers.
Args:
- enums_dict: nested mappings from {enum_name: {member_name: value}}
- consts_dict: mapping from {const_name: value}
- typedefs_dict: mapping from {type_name: ctypes_typename}
- hints_dict: mapping from {var_name: shape_tuple}
- structs_dict: mapping from {struct_name: Struct_instance}
- funcs_dict: mapping from {func_name: Function_instance}
- strings_dict: mapping from {var_name: StaticStringArray_instance}
- func_ptrs_dict: mapping from {var_name: FunctionPtr_instance}
- index_dict: mapping from {lowercase_struct_name: {var_name: shape_tuple}}
+ enums_dict: Nested mappings from {enum_name: {member_name: value}}.
+ consts_dict: Mapping from {const_name: value}.
+ typedefs_dict: Mapping from {type_name: ctypes_typename}.
+ hints_dict: Mapping from {var_name: shape_tuple}.
+ index_dict: Mapping from {lowercase_struct_name: {var_name: shape_tuple}}.
"""
self.enums_dict = (enums_dict if enums_dict is not None
else codegen_util.UniqueOrderedDict())
@@ -81,20 +56,12 @@ def __init__(self, enums_dict=None, consts_dict=None, typedefs_dict=None,
else codegen_util.UniqueOrderedDict())
self.hints_dict = (hints_dict if hints_dict is not None
else codegen_util.UniqueOrderedDict())
- self.structs_dict = (structs_dict if structs_dict is not None
- else codegen_util.UniqueOrderedDict())
- self.funcs_dict = (funcs_dict if funcs_dict is not None
- else codegen_util.UniqueOrderedDict())
- self.strings_dict = (strings_dict if strings_dict is not None
- else codegen_util.UniqueOrderedDict())
- self.func_ptrs_dict = (func_ptrs_dict if func_ptrs_dict is not None
- else codegen_util.UniqueOrderedDict())
self.index_dict = (index_dict if index_dict is not None
else codegen_util.UniqueOrderedDict())
def get_consts_and_enums(self):
consts_and_enums = self.consts_dict.copy()
- for enum in six.itervalues(self.enums_dict):
+ for enum in self.enums_dict.values():
consts_and_enums.update(enum)
return consts_and_enums
@@ -125,11 +92,19 @@ def resolve_size(self, old_size):
# If it's a string specifying a product (such as "2*mjMAXLINEPNT"),
# recursively resolve the components to ints and calculate the result.
size = 1
+ sizes = []
+ is_int = True
for part in old_size.split("*"):
dim = self.resolve_size(part)
- assert isinstance(dim, int)
- size *= dim
- return size
+ sizes.append(dim)
+ if not isinstance(dim, int):
+ is_int = False
+ else:
+ size *= dim
+ if is_int:
+ return size
+ else:
+ return tuple(sizes)
else:
# Recursively dereference any sizes declared in header macros
size = codegen_util.recursive_dict_lookup(old_size,
@@ -172,117 +147,10 @@ def resolve_typename(self, old_ctypes_typename):
new_ctypes_typename, new_ctypes_typename)
if new_ctypes_typename == old_ctypes_typename:
- logging.warn("Could not resolve typename '%s'", old_ctypes_typename)
+ logging.warning("Could not resolve typename '%s'", old_ctypes_typename)
return new_ctypes_typename
- def get_type_from_token(self, token, parent=None):
- """Accepts a token returned by a parser, returns a subclass of CDeclBase."""
-
- comment = codegen_util.mangle_comment(token.comment)
- is_const = token.is_const == "const"
-
- # A new struct declaration
- if token.members:
-
- name = token.name
-
- # If the name is empty, see if there is a type declaration that matches
- # this struct's typename
- if not name:
- for k, v in six.iteritems(self.typedefs_dict):
- if v == token.typename:
- name = k
-
- # Anonymous structs need a dummy typename
- typename = token.typename
- if not typename:
- if parent:
- typename = token.name
- else:
- raise Error(
- "Anonymous structs that aren't members of a named struct are not "
- "supported (name = '{token.name}').".format(token=token))
-
- # Mangle the name if it contains any protected keywords
- name = codegen_util.mangle_varname(name)
-
- members = codegen_util.UniqueOrderedDict()
- sub_structs = codegen_util.UniqueOrderedDict()
- out = c_declarations.Struct(name, typename, members, sub_structs, comment,
- parent, is_const)
-
- # Map the old typename to the mangled typename in typedefs_dict
- self.typedefs_dict[typename] = out.ctypes_typename
-
- # Add members
- for sub_token in token.members:
-
- # Recurse into nested structs
- member = self.get_type_from_token(sub_token, parent=out)
- out.members[member.name] = member
-
- # Nested sub-structures need special treatment
- if isinstance(member, c_declarations.Struct):
- out.sub_structs[member.name] = member
-
- # Add to dict of structs
- self.structs_dict[out.ctypes_typename] = out
-
- else:
-
- name = codegen_util.mangle_varname(token.name)
- typename = self.resolve_typename(token.typename)
-
- # 1D array with size defined at compile time
- if token.size:
- shape = self.get_shape_tuple(token.size)
- if typename in header_parsing.CTYPES_TO_NUMPY:
- out = c_declarations.StaticNDArray(name, typename, shape, comment,
- parent, is_const)
- else:
- out = c_declarations.StaticPtrArray(name, typename, shape, comment,
- parent, is_const)
- elif token.ptr:
-
- # Pointer to a numpy-compatible type, could be an array or a scalar
- if typename in header_parsing.CTYPES_TO_NUMPY:
-
- # Multidimensional array (one or more dimensions might be undefined)
- if name in self.hints_dict:
-
- # Dynamically-sized dimensions have string identifiers
- shape = self.hints_dict[name]
- if any(isinstance(d, str) for d in shape):
- out = c_declarations.DynamicNDArray(name, typename, shape,
- comment, parent, is_const)
- else:
- out = c_declarations.StaticNDArray(name, typename, shape, comment,
- parent, is_const)
-
- # This must be a pointer to a scalar primitive
- else:
- out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
- parent, is_const)
-
- # Pointer to struct or other arbitrary type
- else:
- out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
- parent, is_const)
-
- # A struct we've already encountered
- elif typename in self.structs_dict:
- s = self.structs_dict[typename]
- out = c_declarations.Struct(name, s.typename, s.members, s.sub_structs,
- comment, parent)
-
- # Presumably this is a scalar primitive
- else:
- out = c_declarations.ScalarPrimitive(name, typename, comment, parent,
- is_const)
-
- return out
-
# Parsing functions.
# ----------------------------------------------------------------------------
@@ -292,6 +160,8 @@ def parse_hints(self, xmacro_src):
for tokens, _, _ in parser.scanString(xmacro_src):
for xmacro in tokens:
for member in xmacro.members:
+ if not hasattr(member, "name") or not member.name:
+ continue
# "Squeeze out" singleton dimensions.
shape = self.get_shape_tuple(member.dims, squeeze=True)
self.hints_dict.update({member.name: shape})
@@ -325,7 +195,8 @@ def parse_enums(self, src):
def parse_consts_typedefs(self, src):
"""Updates self.consts_dict, self.typedefs_dict."""
- parser = (header_parsing.COND_DECL | header_parsing.UNCOND_DECL)
+ parser = (header_parsing.COND_DECL |
+ header_parsing.UNCOND_DECL)
for tokens, _, _ in parser.scanString(src):
self.recurse_into_conditionals(tokens)
@@ -353,45 +224,6 @@ def recurse_into_conditionals(self, tokens):
else:
self.consts_dict.update({token.name: True})
- def parse_structs(self, src):
- """Updates self.structs_dict."""
- parser = header_parsing.NESTED_STRUCTS
- for tokens, _, _ in parser.scanString(src):
- for token in tokens:
- self.get_type_from_token(token)
-
- def parse_functions(self, src):
- """Updates self.funcs_dict."""
- parser = header_parsing.MJAPI_FUNCTION_DECL
- for tokens, _, _ in parser.scanString(src):
- for token in tokens:
- name = codegen_util.mangle_varname(token.name)
- comment = codegen_util.mangle_comment(token.comment)
- args = codegen_util.UniqueOrderedDict()
- for arg in token.arguments:
- a = self.get_type_from_token(arg)
- args[a.name] = a
- r = self.get_type_from_token(token.return_value)
- f = c_declarations.Function(name, args, r, comment)
- self.funcs_dict[f.name] = f
-
- def parse_global_strings(self, src):
- """Updates self.strings_dict."""
- parser = header_parsing.MJAPI_STRING_ARRAY
- for token, _, _ in parser.scanString(src):
- name = codegen_util.mangle_varname(token.name)
- shape = self.get_shape_tuple(token.dims)
- self.strings_dict[name] = c_declarations.StaticStringArray(
- name, shape, symbol_name=token.name)
-
- def parse_function_pointers(self, src):
- """Updates self.func_ptrs_dict."""
- parser = header_parsing.MJAPI_FUNCTION_PTR
- for token, _, _ in parser.scanString(src):
- name = codegen_util.mangle_varname(token.name)
- self.func_ptrs_dict[name] = c_declarations.FunctionPtr(
- name, symbol_name=token.name)
-
# Code generation methods
# ----------------------------------------------------------------------------
@@ -405,21 +237,22 @@ def make_header(self, imports=()):
""".format(scriptname=os.path.split(__file__)[-1],
mujoco_version=self.consts_dict["mjVERSION_HEADER"]))
docstring = docstring[1:] # Strip the leading line break.
- return "\n".join(
- [docstring] + _BOILERPLATE_IMPORTS + list(imports) + ["\n"])
+ return "\n".join([docstring] + list(imports) + ["\n"])
def write_consts(self, fname):
+ """Write constants."""
imports = [
"# pylint: disable=invalid-name",
]
with open(fname, "w") as f:
f.write(self.make_header(imports))
f.write(codegen_util.comment_line("Constants") + "\n")
- for name, value in six.iteritems(self.consts_dict):
+ for name, value in self.consts_dict.items():
f.write("{0} = {1}\n".format(name, value))
f.write("\n" + codegen_util.comment_line("End of generated code"))
def write_enums(self, fname):
+ """Write enum definitions."""
with open(fname, "w") as f:
imports = [
"import collections",
@@ -428,9 +261,9 @@ def write_enums(self, fname):
]
f.write(self.make_header(imports))
f.write(codegen_util.comment_line("Enums"))
- for enum_name, members in six.iteritems(self.enums_dict):
- fields = ["\"{}\"".format(name) for name in six.iterkeys(members)]
- values = [str(value) for value in six.itervalues(members)]
+ for enum_name, members in self.enums_dict.items():
+ fields = ["\"{}\"".format(name) for name in members.keys()]
+ values = [str(value) for value in members.values()]
s = textwrap.dedent("""
{0} = collections.namedtuple(
"{0}",
@@ -440,76 +273,8 @@ def write_enums(self, fname):
f.write(s)
f.write("\n" + codegen_util.comment_line("End of generated code"))
- def write_types(self, fname):
- imports = [
- "import ctypes",
- ]
- with open(fname, "w") as f:
- f.write(self.make_header(imports))
- f.write(codegen_util.comment_line("ctypes struct declarations"))
- for struct in six.itervalues(self.structs_dict):
- f.write("\n" + struct.ctypes_struct_decl)
- f.write("\n" + codegen_util.comment_line("End of generated code"))
-
- def write_wrappers(self, fname):
- with open(fname, "w") as f:
- imports = [
- "import ctypes",
- "# Internal dependencies.",
- "# pylint: disable=undefined-variable",
- "# pylint: disable=wildcard-import",
- "from {} import util".format(_MODULE),
- "from {}.mjbindings.types import *".format(_MODULE),
- ]
- f.write(self.make_header(imports))
- f.write(codegen_util.comment_line("Low-level wrapper classes"))
- for struct in six.itervalues(self.structs_dict):
- f.write("\n" + struct.wrapper_class)
- f.write("\n" + codegen_util.comment_line("End of generated code"))
-
- def write_funcs_and_globals(self, fname):
- """Write ctypes declarations for functions and global data."""
- imports = [
- "import collections",
- "import ctypes",
- "# Internal dependencies.",
- "# pylint: disable=undefined-variable",
- "# pylint: disable=wildcard-import",
- "from {} import util".format(_MODULE),
- "from {}.mjbindings.types import *".format(_MODULE),
- "import numpy as np",
- "# pylint: disable=line-too-long",
- "# pylint: disable=invalid-name",
- "# common_typos_disable",
- ]
- with open(fname, "w") as f:
- f.write(self.make_header(imports))
- f.write("mjlib = util.get_mjlib()\n")
-
- f.write("\n" + codegen_util.comment_line("ctypes function declarations"))
- for function in six.itervalues(self.funcs_dict):
- f.write("\n" + function.ctypes_func_decl(cdll_name="mjlib"))
-
- # Only require strings for UI purposes.
- f.write("\n" + codegen_util.comment_line("String arrays") + "\n")
- for string_arr in six.itervalues(self.strings_dict):
- f.write(string_arr.ctypes_var_decl(cdll_name="mjlib"))
-
- f.write("\n" + codegen_util.comment_line("Function pointers"))
-
- fields = [repr(name) for name in self.func_ptrs_dict.keys()]
- values = [func_ptr.ctypes_var_decl(cdll_name="mjlib")
- for func_ptr in self.func_ptrs_dict.values()]
- f.write(textwrap.dedent("""
- function_pointers = collections.namedtuple(
- 'FunctionPointers',
- [{0}]
- )({1})
- """).format(",\n ".join(fields), ",\n ".join(values)))
-
- f.write("\n" + codegen_util.comment_line("End of generated code"))
-
def write_index_dict(self, fname):
+ """Write file containing array shape information for indexing."""
pp = pprint.PrettyPrinter()
output_string = pp.pformat(dict(self.index_dict))
indent = codegen_util.Indenter()
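
Review note: the `resolve_size` change above no longer asserts that every factor of a product such as `"2*mjMAXLINEPNT"` resolves to an int. It returns the product when all factors are ints and a tuple of the factors otherwise. The sketch below shows that behaviour in isolation. The function `resolve_product` and the `known_sizes` dict are hypothetical stand-ins for the recursive lookup done by `BindingGenerator`, not its actual API.

```python
def resolve_product(size_expr, known_sizes):
  """Resolves "a*b*..." to an int if every factor is known, else to a tuple.

  Hypothetical illustration; `known_sizes` maps macro names to ints.
  """
  dims = []
  for part in size_expr.split("*"):
    part = part.strip()
    if part.isdigit():
      dims.append(int(part))
    else:
      # Unknown names are kept as strings (dynamically sized dimensions).
      dims.append(known_sizes.get(part, part))
  if all(isinstance(d, int) for d in dims):
    product = 1
    for d in dims:
      product *= d
    return product
  return tuple(dims)


if __name__ == "__main__":
  print(resolve_product("2*mjMAXLINEPNT", {"mjMAXLINEPNT": 1000}))
  print(resolve_product("nmocap*3", {}))
```

Callers that previously relied on always receiving an int from a product expression need to handle the tuple case as well.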
diff --git a/dm_control/autowrap/c_declarations.py b/dm_control/autowrap/c_declarations.py
deleted file mode 100644
index 9bdd3eb0..00000000
--- a/dm_control/autowrap/c_declarations.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Python representations of C declarations."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import textwrap
-
-# Internal dependencies.
-
-from dm_control.autowrap import codegen_util
-from dm_control.autowrap import header_parsing
-
-import six
-
-
-class CDeclBase(object):
- """Base class for Python representations of C declarations."""
-
- def __init__(self, **attrs):
- self._attrs = attrs
- for k, v in six.iteritems(attrs):
- setattr(self, k, v)
-
- def __repr__(self):
- """Pretty string representation."""
- attr_str = ", ".join("{0}={1!r}".format(k, v)
- for k, v in six.iteritems(self._attrs))
- return "{0}({1})".format(type(self).__name__, attr_str)
-
- @property
- def docstring(self):
- """Auto-generate a docstring for self."""
- return "\n".join(textwrap.wrap(self.comment, 74))
-
- @property
- def ctypes_typename(self):
- """ctypes typename."""
- return self.typename
-
- @property
- def ctypes_ptr(self):
- """String representation of self as a ctypes pointer."""
- return header_parsing.CTYPES_PTRS.get(
- self.ctypes_typename, "ctypes.POINTER({})".format(self.ctypes_typename))
-
- @property
- def np_dtype(self):
- """Get a numpy dtype name for self, fall back on self.ctypes_typename."""
- return header_parsing.CTYPES_TO_NUMPY.get(self.ctypes_typename,
- self.ctypes_typename)
-
- @property
- def np_flags(self):
- """Tuple of strings specifying numpy.ndarray flags."""
- return ("C", "W")
-
-
-class Struct(CDeclBase):
- """C struct declaration."""
-
- def __init__(self, name, typename, members, sub_structs, comment="",
- parent=None, is_const=None):
- super(Struct, self).__init__(name=name,
- typename=typename,
- members=members,
- sub_structs=sub_structs,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def ctypes_struct_decl(self):
- """Generates a ctypes.Structure declaration for self."""
- indent = codegen_util.Indenter()
- s = textwrap.dedent("""
- class {0.ctypes_typename:}(ctypes.Structure):
- \"\"\"{0.docstring:}\"\"\"
- """.format(self))
- with indent:
- if self.members:
- s += indent("\n_fields_ = [\n")
- with indent:
- with indent:
- s += ",\n".join(indent(m.ctypes_field_decl)
- for m in six.itervalues(self.members))
- s += indent("\n]\n")
- return s
-
- @property
- def ctypes_typename(self):
- """Mangles ctypes.Structure typenames to distinguish them from wrappers."""
- return codegen_util.mangle_struct_typename(self.typename)
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- return "('{0.name:}', {0.ctypes_typename:})".format(self) # pylint: disable=missing-format-attribute
-
- @property
- def wrapper_name(self):
- return codegen_util.camel_case(self.typename) + "Wrapper"
-
- @property
- def wrapper_class(self):
- """Generates a Python class containing getter/setter methods for members."""
- indent = codegen_util.Indenter()
- s = textwrap.dedent("""
- class {0.wrapper_name}(util.WrapperBase):
- \"\"\"{0.docstring:}\"\"\"
- """.format(self))
- with indent:
- s += "".join(indent(m.getters_setters)
- for m in six.itervalues(self.members))
- return s
-
- @property
- def getters_setters(self):
- """Populates a Python class with getter & setter methods for self."""
- return textwrap.dedent("""
- @util.CachedProperty
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return {0.wrapper_name}(ctypes.pointer(self._ptr.contents.{0.name}))
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """String representation of self as a ctypes function argument."""
- return self.ctypes_typename
-
-
-class ScalarPrimitive(CDeclBase):
- """A scalar value corresponding to a C primitive type."""
-
- def __init__(self, name, typename, comment="", parent=None, is_const=None):
- super(ScalarPrimitive, self).__init__(name=name,
- typename=typename,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- return "('{0.name:}', {0.ctypes_typename:})".format(self) # pylint: disable=missing-format-attribute
-
- @property
- def getters_setters(self):
- """Populates a Python class with getter & setter methods for self."""
- return textwrap.dedent("""
- @property
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return self._ptr.contents.{0.name:}
-
- @{0.name:}.setter
- def {0.name:}(self, value):
- self._ptr.contents.{0.name:} = value
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """String representation of self as a ctypes function argument."""
- return self.ctypes_typename
-
-
-class ScalarPrimitivePtr(CDeclBase):
- """Pointer to a ScalarPrimitive."""
-
- def __init__(self, name, typename, comment="", parent=None, is_const=None):
- super(ScalarPrimitivePtr, self).__init__(name=name,
- typename=typename,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- return "('{0.name:}', {0.ctypes_ptr:})".format(self) # pylint: disable=missing-format-attribute
-
- @property
- def getters_setters(self):
- """Populates a Python class with getter & setter methods for self."""
- return textwrap.dedent("""
- @property
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return self._ptr.contents.{0.name:}
-
- @{0.name:}.setter
- def {0.name:}(self, value):
- self._ptr.contents.{0.name:} = value
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """Generates string representation of self as a ctypes function argument."""
- # we assume that every pointer that maps to a numpy dtype corresponds to an
- # array argument/return value
- if self.ctypes_typename in header_parsing.CTYPES_TO_NUMPY:
- return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s:})"
- "".format(self)) # pylint: disable=missing-format-attribute
- else:
- return self.ctypes_ptr
-
-
-class StaticPtrArray(CDeclBase):
- """Array of arbitrary pointers whose size can be inferred from the headers."""
-
- def __init__(self, name, typename, shape, comment="", parent=None,
- is_const=None):
- super(StaticPtrArray, self).__init__(name=name,
- typename=typename,
- shape=shape,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- if self.typename in header_parsing.CTYPES_PTRS:
- return "('{0.name:}', {0.ctypes_ptr:} * {1:})".format( # pylint: disable=missing-format-attribute
- self, " * ".join(str(d) for d in self.shape))
- else:
- return "('{0.name:}', {0.ctypes_typename:} * {1:})".format( # pylint: disable=missing-format-attribute
- self, " * ".join(str(d) for d in self.shape))
-
- @property
- def getters_setters(self):
- """Populates a Python class with getter & setter methods for self."""
- return textwrap.dedent("""
- @property
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return self._ptr.contents.{0.name:}
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """Generates string representation of self as a ctypes function argument."""
- return "{0.ctypes_typename:}".format(self)
-
-
-class StaticNDArray(CDeclBase):
- """Numeric array whose dimensions can all be inferred from the headers."""
-
- def __init__(self, name, typename, shape, comment="", parent=None,
- is_const=None):
- super(StaticNDArray, self).__init__(name=name,
- typename=typename,
- shape=shape,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- return "('{0.name:}', {0.ctypes_typename:} * ({1:}))".format( # pylint: disable=missing-format-attribute
- self, " * ".join(str(d) for d in self.shape))
-
- @property
- def getters_setters(self):
- """Populates a Python class with a getter method for self (no setter)."""
- return textwrap.dedent("""
- @util.CachedProperty
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return util.buf_to_npy(self._ptr.contents.{0.name:}, {0.shape!s:})
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """Generates string representation of self as a ctypes function argument."""
- return ("util.ndptr(shape={0.shape}, dtype={0.np_dtype}, " # pylint: disable=missing-format-attribute
- "flags={0.np_flags!s})".format(self))
-
-
-class DynamicNDArray(CDeclBase):
- """Numeric array where one or more dimensions are determined at runtime."""
-
- def __init__(self, name, typename, shape, comment="", parent=None,
- is_const=None):
- super(DynamicNDArray, self).__init__(name=name,
- typename=typename,
- shape=shape,
- comment=comment,
- parent=parent,
- is_const=is_const)
-
- @property
- def runtime_shape_str(self):
- """String representation of shape tuple at runtime."""
- rs = []
- for d in self.shape:
- # dynamically-sized dimension
- if isinstance(d, str):
- if self.parent and d in self.parent.members:
- rs.append("self.{}".format(d))
- else:
- rs.append("self._model.{}".format(d))
- # static dimension
- else:
- rs.append(str(d))
- return str(tuple(rs)).replace("'", "") # strip quotes from string rep
-
- @property
- def ctypes_field_decl(self):
- """Generates a declaration for self as a field of a ctypes.Structure."""
- return "('{0.name:}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute
-
- @property
- def getters_setters(self):
- """Populates a Python class with a getter method for self (no setter)."""
- return textwrap.dedent("""
- @util.CachedProperty
- def {0.name:}(self):
- \"\"\"{0.docstring:}\"\"\"
- return util.buf_to_npy(self._ptr.contents.{0.name:},
- {0.runtime_shape_str:})
- """.format(self)) # pylint: disable=missing-format-attribute
-
- @property
- def arg(self):
- """Generates string representation of self as a ctypes function argument."""
- return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s:})"
- "".format(self)) # pylint: disable=missing-format-attribute
-
-
-class Function(CDeclBase):
- """A function declaration including input type(s) and return type."""
-
- def __init__(self, name, arguments, return_value, comment=""):
- super(Function, self).__init__(name=name,
- arguments=arguments,
- return_value=return_value,
- comment=comment)
-
- def ctypes_func_decl(self, cdll_name):
- """Generates a ctypes function declaration."""
- indent = codegen_util.Indenter()
- # triple-quoted docstring
- s = ("{0:}.{1.name:}.__doc__ = \"\"\"\n{1.docstring:}\"\"\"\n" # pylint: disable=missing-format-attribute
- ).format(cdll_name, self)
- # arguments
- s += "{0:}.{1.name:}.argtypes = [".format(cdll_name, self) # pylint: disable=missing-format-attribute
- if len(self.arguments) > 1:
- s += "\n"
- with indent:
- with indent:
- s += ",\n".join(indent(a.arg) for a in six.itervalues(self.arguments))
- s += "\n"
- else:
- s += ", ".join(indent(a.arg) for a in six.itervalues(self.arguments))
- s += "]\n"
- # return value
- s += "{0:}.{1.name:}.restype = {2:}\n".format( # pylint: disable=missing-format-attribute
- cdll_name, self, self.return_value.arg)
- return s
-
- @property
- def docstring(self):
- """Generates a docstring."""
- indent = codegen_util.Indenter()
- s = "\n".join(textwrap.wrap(self.comment, 80)) + "\n\nArgs:\n"
- with indent:
- for a in six.itervalues(self.arguments):
- s += indent("{a.name:}: {a.arg:}{const:}\n".format(
- a=a, const=(" " if a.is_const else "")))
- s += "Returns:\n"
- with indent:
- s += indent("{0.return_value.arg}{1:}\n".format( # pylint: disable=missing-format-attribute
- self, (" " if self.return_value.is_const else "")))
- return s
-
-
-class StaticStringArray(CDeclBase):
- """A string array of fixed dimensions exported by MuJoCo."""
-
- def __init__(self, name, shape, symbol_name):
- super(StaticStringArray, self).__init__(name=name,
- shape=shape,
- symbol_name=symbol_name)
-
- def ctypes_var_decl(self, cdll_name=""):
- """Generates a ctypes export statement."""
-
- ptr_str = "ctypes.c_char_p"
- for dim in self.shape[::-1]:
- ptr_str = "({0} * {1!s})".format(ptr_str, dim)
-
- return "{0} = {1}.in_dll({2}, {3!r})\n".format(
- self.name, ptr_str, cdll_name, self.symbol_name)
-
-
-class FunctionPtr(CDeclBase):
- """A pointer to an externally defined C function."""
-
- def __init__(self, name, symbol_name, type_name=None):
- super(FunctionPtr, self).__init__(
- name=name, symbol_name=symbol_name, type_name=type_name)
-
- def ctypes_var_decl(self, cdll_name=""):
- """Generates a ctypes export statement."""
-
- return "ctypes.c_void_p.in_dll({0}, {1!r})".format(
- cdll_name, self.symbol_name)
diff --git a/dm_control/autowrap/codegen_util.py b/dm_control/autowrap/codegen_util.py
index 5328319c..81946e55 100644
--- a/dm_control/autowrap/codegen_util.py
+++ b/dm_control/autowrap/codegen_util.py
@@ -15,31 +15,14 @@
"""Misc helper functions needed by autowrap.py."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-import keyword
-import re
-
-# Internal dependencies.
-import six
-from six.moves import builtins
_MJXMACRO_SUFFIX = "_POINTERS"
-_PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist + dir(builtins))
-if not six.PY2:
- _PYTHON_RESERVED_KEYWORDS.add("buffer")
-class Indenter(object):
+class Indenter:
r"""Callable context manager for tracking string indentation levels.
- Args:
- level: The initial indentation level.
- indent_str: The string used to indent each line.
-
Example usage:
```python
@@ -56,6 +39,12 @@ class Indenter(object):
"""
def __init__(self, level=0, indent_str=" "):
+ """Initializes an Indenter.
+
+ Args:
+ level: The initial indentation level.
+ indent_str: The string used to indent each line.
+ """
self.indent_str = indent_str
self.level = level
@@ -80,49 +69,27 @@ class UniqueOrderedDict(collections.OrderedDict):
"""Subclass of `OrderedDict` that enforces the uniqueness of keys."""
def __setitem__(self, k, v):
- if k in self:
+ existing_v = self.get(k)
+ if existing_v is None:
+ super().__setitem__(k, v)
+ elif v != existing_v:
raise ValueError("Key '{}' already exists.".format(k))
- super(UniqueOrderedDict, self).__setitem__(k, v)
def macro_struct_name(name, suffix=None):
"""Converts mjxmacro struct names, e.g. "MJDATA_POINTERS" to "mjdata"."""
+ if name.startswith("MJMODEL_POINTERS"):
+ return "mjmodel"
if suffix is None:
suffix = _MJXMACRO_SUFFIX
- return name[:-len(suffix)].lower()
+ if name.endswith(suffix):
+ return name[:-len(suffix)].lower()
+ return name.lower()
def is_macro_pointer(name):
"""Returns True if the mjxmacro struct name contains pointer sizes."""
- return name.endswith(_MJXMACRO_SUFFIX)
-
-
-def mangle_varname(s):
- """Append underscores to ensure that `s` is not a reserved Python keyword."""
- while s in _PYTHON_RESERVED_KEYWORDS:
- s += "_"
- return s
-
-
-def mangle_struct_typename(s):
- """Strip leading underscores and make uppercase."""
- return s.lstrip("_").upper()
-
-
-def mangle_comment(s):
- """Strip extraneous whitespace, add full-stops at end of each line."""
- if not isinstance(s, str):
- return "\n".join(mangle_comment(line) for line in s)
- elif not s:
- return "."
- else:
- return ".\n".join(" ".join(line.split()) for line in s.splitlines()) + "."
-
-
-def camel_case(s):
- """Convert a snake_case string (maybe with lowerCaseFirst) to CamelCase."""
- tokens = re.sub(r"([A-Z])", r" \1", s.replace("_", " ")).split()
- return "".join(w.title() for w in tokens)
+ return name.endswith(_MJXMACRO_SUFFIX) or name.startswith("MJMODEL_POINTERS")
def try_coerce_to_num(s, try_types=(int, float)):
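The `codegen_util.py` hunk above changes two behaviours that are easy to misread. `UniqueOrderedDict.__setitem__` now tolerates re-inserting a key with an equal value, and `macro_struct_name` strips the suffix only when it is actually present. The sketch below copies both definitions from the hunk so that they run without `dm_control` installed. Apart from `MJDATA_POINTERS`, which the docstring mentions, the macro names passed in are illustrative inputs chosen for this example.

```python
import collections

_MJXMACRO_SUFFIX = "_POINTERS"


class UniqueOrderedDict(collections.OrderedDict):
  """Copy of the class from the hunk: keys may not change value."""

  def __setitem__(self, k, v):
    existing_v = self.get(k)
    if existing_v is None:
      # Key is absent, or its stored value is None: (re)insert it.
      super().__setitem__(k, v)
    elif v != existing_v:
      raise ValueError("Key '{}' already exists.".format(k))


def macro_struct_name(name, suffix=None):
  """Copy of the function from the hunk."""
  if name.startswith("MJMODEL_POINTERS"):
    return "mjmodel"
  if suffix is None:
    suffix = _MJXMACRO_SUFFIX
  if name.endswith(suffix):
    return name[:-len(suffix)].lower()
  return name.lower()


d = UniqueOrderedDict()
d["a"] = 1
d["a"] = 1  # Same value again: silently accepted.
try:
  d["a"] = 2  # Different value: rejected.
  conflict_raised = False
except ValueError:
  conflict_raised = True

print(conflict_raised)                               # True
print(macro_struct_name("MJDATA_POINTERS"))          # mjdata
print(macro_struct_name("MJMODEL_POINTERS_EXAMPLE")) # mjmodel
print(macro_struct_name("EXAMPLE_MACRO"))            # example_macro
```

One consequence of using `self.get(k)` is that a key whose stored value is `None` counts as absent, so it can be overwritten without an error.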
diff --git a/dm_control/autowrap/header_parsing.py b/dm_control/autowrap/header_parsing.py
index 4851fbc1..b13802d7 100644
--- a/dm_control/autowrap/header_parsing.py
+++ b/dm_control/autowrap/header_parsing.py
@@ -15,32 +15,31 @@
"""pyparsing definitions and helper functions for parsing MuJoCo headers."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
import pyparsing as pp
-import six
+
# NB: Don't enable parser memoization (`pp.ParserElement.enablePackrat()`),
# since this results in a ~6x slowdown.
+
+NONE = "None"
+CTYPES_CHAR = "ctypes.c_char"
+
C_TO_CTYPES = {
# integers
"int": "ctypes.c_int",
"unsigned int": "ctypes.c_uint",
- "char": "ctypes.c_char",
+ "char": CTYPES_CHAR,
"unsigned char": "ctypes.c_ubyte",
"size_t": "ctypes.c_size_t",
# floats
"float": "ctypes.c_float",
"double": "ctypes.c_double",
# pointers
- "void": "None",
+ "void": NONE,
}
-CTYPES_PTRS = {"None": "ctypes.c_void_p",}
+CTYPES_PTRS = {NONE: "ctypes.c_void_p"}
CTYPES_TO_NUMPY = {
# integers
@@ -79,10 +78,24 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
return ifelse
+def _nested_ifn_else(ifn_, pred, else_, endif, match_if_true, match_if_false):
+ """Constructs a parser for (possibly nested) if...(else)...endif blocks."""
+ ifnelse = pp.Forward()
+ ifnelse << pp.Group( # pylint: disable=expression-not-assigned
+ ifn_ +
+ pred("predicate") +
+ pp.ZeroOrMore(match_if_true | ifnelse)("if_false") +
+ pp.Optional(else_ +
+ pp.ZeroOrMore(match_if_false | ifnelse)("if_true")) +
+ endif)
+ return ifnelse
+
+
# Some common string patterns to suppress.
# ------------------------------------------------------------------------------
-(X, LPAREN, RPAREN, LBRACK, RBRACK, LBRACE, RBRACE, SEMI, COMMA, EQUAL, FSLASH,
- BSLASH) = map(pp.Suppress, "X()[]{};,=/\\")
+(LPAREN, RPAREN, LBRACK, RBRACK, LBRACE, RBRACE, SEMI, COMMA, EQUAL, FSLASH,
+ BSLASH) = list(map(pp.Suppress, "()[]{};,=/\\"))
+X = (pp.Keyword("X") | pp.Keyword("XMJV") | pp.Keyword("XNV")).suppress()
EOL = pp.LineEnd().suppress()
# Comments, continuation.
@@ -122,8 +135,7 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
PTR = pp.Literal("*")
EXTERN = pp.Keyword("extern")
-NATIVE_TYPENAME = pp.MatchFirst(
- [pp.Keyword(n) for n in six.iterkeys(C_TO_CTYPES)])
+NATIVE_TYPENAME = pp.MatchFirst([pp.Keyword(n) for n in C_TO_CTYPES.keys()])
# Macros.
# ------------------------------------------------------------------------------
@@ -144,6 +156,12 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
(COMMENT("comment") | EOL))
# e.g. "X( mjtNum*, name_textadr, ntext, 1 )"
+XDIM = pp.delimitedList(
+ (
+ pp.Suppress(pp.Keyword("MJ_M") + LPAREN) +
+ NAME +
+ pp.Suppress(RPAREN)
+ ) | NAME | INT, delim="*", combine=True)
XMEMBER = pp.Group(
X +
LPAREN +
@@ -152,21 +170,23 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
COMMA +
NAME("name") +
COMMA +
- pp.delimitedList((INT | NAME), delim=COMMA)("dims") +
+ pp.delimitedList(XDIM, delim=COMMA)("dims") +
RPAREN)
+XMACRO_LINE = XMEMBER | NAME
XMACRO = pp.Group(
pp.Optional(COMMENT("comment")) +
DEFINE +
NAME("name") +
CONTINUATION +
- pp.delimitedList(XMEMBER, delim=CONTINUATION)("members"))
+ pp.delimitedList(XMACRO_LINE, delim=CONTINUATION)("members"))
# Type/variable declarations.
# ------------------------------------------------------------------------------
TYPEDEF = pp.Keyword("typedef").suppress()
-STRUCT = pp.Keyword("struct").suppress()
+STRUCT = pp.Keyword("struct")
+UNION = pp.Keyword("union")
ENUM = pp.Keyword("enum").suppress()
# e.g. "typedef unsigned char mjtByte; // used for true/false"
@@ -183,11 +203,14 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
UNCOND_DECL = DEF_FLAG | DEF_CONST | TYPE_DECL
# Declarations inside (possibly nested) #if(n)def... #else... #endif... blocks.
-COND_DECL = _nested_if_else(IFDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL)
+COND_DECL = _nested_if_else(
+ IFDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL
+) | _nested_ifn_else(IFNDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL)
# Note: this doesn't work for '#if defined(FLAG)' blocks
# e.g. "mjtNum gravity[3]; // gravitational acceleration"
STRUCT_MEMBER = pp.Group(
+ pp.Optional(STRUCT("struct")) +
(NATIVE_TYPENAME | NAME)("typename") +
pp.Optional(PTR("ptr")) +
NAME("name") +
@@ -195,8 +218,9 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
SEMI +
pp.Optional(COMMENT("comment")))
-STRUCT_DECL = pp.Group(
- STRUCT +
+# Struct declaration within a union (non-nested).
+UNION_STRUCT_DECL = pp.Group(
+ STRUCT("struct") +
pp.Optional(NAME("typename")) +
pp.Optional(COMMENT("comment")) +
LBRACE +
@@ -205,6 +229,17 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
pp.Optional(NAME("name")) +
SEMI)
+ANONYMOUS_UNION_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ UNION("anonymous_union") +
+ LBRACE +
+ pp.OneOrMore(
+ UNION_STRUCT_DECL |
+ STRUCT_MEMBER |
+ COMMENT.suppress())("members") +
+ RBRACE +
+ SEMI)
+
# Multiple (possibly nested) struct declarations.
NESTED_STRUCTS = _nested_scopes(
opening=(STRUCT +
@@ -213,7 +248,9 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
LBRACE),
closing=(RBRACE + pp.Optional(NAME("name")) + SEMI),
body=pp.OneOrMore(
- STRUCT_MEMBER | STRUCT_DECL | COMMENT.suppress())("members"))
+ STRUCT_MEMBER |
+ ANONYMOUS_UNION_DECL |
+ COMMENT.suppress())("members"))
BIT_LSHIFT = INT("bit_lshift_a") + pp.Suppress("<<") + INT("bit_lshift_b")
@@ -236,8 +273,9 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
# Function declarations.
# ------------------------------------------------------------------------------
-MJAPI = pp.Keyword("MJAPI")
+MJAPI = pp.Keyword("MJAPI").suppress()
CONST = pp.Keyword("const")
+VOID = pp.Group(pp.Keyword("void") + ~PTR).suppress()
ARG = pp.Group(
pp.Optional(CONST("is_const")) +
@@ -252,18 +290,35 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
pp.Optional(PTR("ptr")))
FUNCTION_DECL = (
- RET("return_value") +
+ (VOID | RET("return_value")) +
NAME("name") +
LPAREN +
- pp.delimitedList(ARG, delim=COMMA)("arguments") +
+ (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) +
RPAREN +
SEMI)
MJAPI_FUNCTION_DECL = pp.Group(
pp.Optional(MULTILINE_COMMENT("comment")) +
+ pp.LineStart() +
MJAPI +
FUNCTION_DECL)
+# e.g.
+# // predicate function: set enable/disable based on item category
+# typedef int (*mjfItemEnable)(int category, void* data);
+FUNCTION_PTR_TYPE_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ TYPEDEF +
+ RET("return_type") +
+ LPAREN +
+ PTR +
+ NAME("typename") +
+ RPAREN +
+ LPAREN +
+ (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) +
+ RPAREN +
+ SEMI)
+
# Global variables.
# ------------------------------------------------------------------------------
diff --git a/dm_control/blender/fake_core/bpy.py b/dm_control/blender/fake_core/bpy.py
new file mode 100644
index 00000000..5ea814e3
--- /dev/null
+++ b/dm_control/blender/fake_core/bpy.py
@@ -0,0 +1,229 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Fake Blender bpy module."""
+
+from __future__ import annotations # postponed evaluation of annotations
+
+from typing import Any, Collection, Sequence
+
+from dm_control.blender.fake_core import mathutils
+
+# pylint: disable=invalid-name
+# pylint: disable=missing-class-docstring
+
+
+class WindowManager:
+
+ def progress_begin(self, start: int, end: int):
+ pass
+
+ def progress_update(self, steps_done: int):
+ pass
+
+ def progress_end(self):
+ pass
+
+
+class context:
+
+ @property
+ def window_manager(self) -> WindowManager:
+ return WindowManager()
+
+ @staticmethod
+ def evaluated_depsgraph_get():
+ pass
+
+
+class types:
+
+ class Constraint:
+
+ @property
+ def name(self):
+ pass
+
+ @property
+ def owner_space(self):
+ pass
+
+ class Scene:
+ pass
+
+ class Object:
+
+ @property
+ def name(self) -> str:
+ raise NotImplementedError()
+
+ @property
+ def parent(self) -> types.Object | None:
+ pass
+
+ @property
+ def parent_bone(self) -> types.Bone | None:
+ pass
+
+ @property
+ def data(self):
+ pass
+
+ @property
+ def pose(self):
+ pass
+
+ @property
+ def matrix_world(self) -> mathutils.Matrix:
+ raise NotImplementedError()
+
+ @matrix_world.setter
+ def matrix_world(self, _) -> None:
+ raise NotImplementedError()
+
+ def select_set(self, _):
+ pass
+
+ def to_mesh(self):
+ pass
+
+ def evaluated_get(self, _) -> types.Object:
+ pass
+
+ @property
+ def mode(self) -> str:
+ return 'OBJECT'
+
+ @property
+ def type(self):
+ pass
+
+ def update_from_editmode(self):
+ pass
+
+ class Bone:
+
+ @property
+ def name(self) -> str:
+ raise NotImplementedError()
+
+ @property
+ def parent(self) -> types.Bone | None:
+ pass
+
+ @property
+ def matrix_local(self) -> mathutils.Matrix:
+ raise NotImplementedError()
+
+ @property
+ def matrix(self) -> mathutils.Matrix:
+ raise NotImplementedError()
+
+ class bpy_struct:
+ pass
+
+ class Context:
+
+ @property
+ def scene(self) -> types.Scene:
+ pass
+
+ class Light:
+
+ @property
+ def type(self):
+ pass
+
+ @property
+ def use_shadow(self):
+ pass
+
+ @property
+ def color(self) -> mathutils.Color:
+ raise NotImplementedError()
+
+ @property
+ def linear_attenuation(self):
+ pass
+
+ @property
+ def quadratic_attenuation(self):
+ pass
+
+ class LimitRotationConstraint(Constraint):
+ pass
+
+ class LimitLocationConstraint(Constraint):
+ pass
+
+ class Material:
+
+ @property
+ def name(self) -> str:
+ raise NotImplementedError()
+
+ @property
+ def specular_intensity(self):
+ pass
+
+ @property
+ def metallic(self):
+ pass
+
+ @property
+ def roughness(self) -> float:
+ raise NotImplementedError()
+
+ @property
+ def diffuse_color(self) -> Sequence[float]:
+ raise NotImplementedError()
+
+ class Mesh:
+
+ @property
+ def name(self) -> str:
+ raise NotImplementedError()
+
+ def calc_loop_triangles(self):
+ pass
+
+ @property
+ def uv_layers(self) -> Any:
+ raise NotImplementedError()
+
+ @property
+ def loop_triangles(self) -> Collection[Any]:
+ raise NotImplementedError()
+
+ @property
+ def vertices(self) -> Any:
+ raise NotImplementedError()
+
+
+class ops:
+
+ class object:
+
+ @staticmethod
+ def select_all(action):
+ pass
+
+ class export_mesh:
+
+ @classmethod
+ def stl(
+ cls, filepath, use_selection, use_mesh_modifiers, axis_forward, axis_up
+ ):
+ pass
diff --git a/dm_control/blender/fake_core/mathutils.py b/dm_control/blender/fake_core/mathutils.py
new file mode 100644
index 00000000..a48c21e9
--- /dev/null
+++ b/dm_control/blender/fake_core/mathutils.py
@@ -0,0 +1,123 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Fake Blender mathutils module."""
+# pylint: disable=invalid-name
+
+import numpy as np
+
+
+class Color:
+ """Fake color class."""
+
+ def __init__(self, coords):
+ self._coords = coords
+
+ @property
+ def r(self) -> float:
+ return self._coords[0]
+
+ @property
+ def g(self) -> float:
+ return self._coords[1]
+
+ @property
+ def b(self) -> float:
+ return self._coords[2]
+
+ @property
+ def a(self) -> float:
+ return self._coords[3]
+
+
+class Vector:
+ """Fake vector class."""
+
+ def __init__(self, coords):
+ self._coords = np.asarray(coords)
+
+ @property
+ def x(self) -> float:
+ return self._coords[0]
+
+ @property
+ def y(self) -> float:
+ return self._coords[1]
+
+ @property
+ def z(self) -> float:
+ return self._coords[2]
+
+ @property
+ def w(self) -> float:
+ return self._coords[3]
+
+ def __eq__(self, rhs) -> bool:
+ if not isinstance(rhs, Vector):
+ return False
+ return np.linalg.norm(self._coords - rhs._coords) < 1e-6
+
+ def __str__(self) -> str:
+ return 'Vector({:.2f}, {:.2f}, {:.2f})'.format(self.x, self.y, self.z)
+
+ def __repr__(self) -> str:
+ return 'Vector({:.2f}, {:.2f}, {:.2f})'.format(self.x, self.y, self.z)
+
+
+class Quaternion:
+ """Fake quaternion class."""
+
+ def __init__(self, coords):
+ self._coords = coords
+
+ @property
+ def x(self) -> float:
+ return self._coords[1]
+
+ @property
+ def y(self) -> float:
+ return self._coords[2]
+
+ @property
+ def z(self) -> float:
+ return self._coords[3]
+
+ @property
+ def w(self) -> float:
+ return self._coords[0]
+
+ def __matmul__(self, rhs: Vector) -> Vector:
+ return Vector((1, 0, 0))
+
+
+class Matrix:
+ """Fake matrix class."""
+
+ def __init__(self, coords):
+ self._coords = coords
+
+ @classmethod
+ def Diagonal(cls, _):
+ return cls((1,))
+
+ @property
+ def translation(self) -> Vector:
+ return Vector((0.0, 0.0, 0.0))
+
+ def to_quaternion(self) -> Quaternion:
+ return Quaternion((1.0, 0.0, 0.0, 0.0))
+
+ def to_scale(self) -> Vector:
+ return Vector((0.0, 0.0, 0.0))
diff --git a/dm_control/blender/mujoco_exporter/README.md b/dm_control/blender/mujoco_exporter/README.md
new file mode 100644
index 00000000..a1470d9c
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/README.md
@@ -0,0 +1,213 @@
+# Export MuJoCo models from Blender
+
+## Prerequisites
+
+The MuJoCo exporter is a Blender add-on, tested with Blender 3.4.1. You can
+download Blender from [https://www.blender.org](https://www.blender.org/).
+
+## Installation
+
+### Preparing a plugin installer
+
+The `install.sh` command will deploy the installable version of the plugin to
+`./addons/mujoco_model_exporter`. This folder should be zipped:
+
+```shell
+$ cd ./addons && zip -r mujoco_model_exporter.zip mujoco_model_exporter/*
+```
+
+### Installing the plugin in Blender
+
+Open Blender and select the `Edit/Preferences` menu option.
+
+Clicking the Install button will open a file selection dialog that will allow
+you to select the .zip archive with the plugin.
+
+Blender then shows a window that lists the installed plugin and lets you
+enable it.
+
+Check the checkbox next to its name to enable it.
+
+## Modelling for MuJoCo
+
+The plugin was designed to let artists keep using core Blender features when
+building models for MuJoCo.
+
+Feature | MuJoCo | Blender
+-------------- | ------------------------- | -----------------------------------
+kinematic tree | hierarchy of Bodies | Armatures or hierarchies of Objects
+geometry | Geoms | Meshes
+materials | Material + Texture assets | Material + Texture definitions
+lighting | Lights | Lights
+
+### Node naming
+
+The exporter will copy the names of the bones, meshes and lights. MuJoCo joints
+will be named after the bone that owns the respective constraint, with a suffix
+denoting the degree of freedom they enable.
+
+### Modelling kinematic trees using Armatures and IK
+
+Blender Armatures model the kinematic capabilities of a Blender model. They
+comprise a tree of bones, each of which can be further extended with IK
+Constraints.
+
+Only the bones affected by an IK chain will receive appropriate degrees of
+freedom. This means that if you export a model without any IK chains, it will be
+exported as a static model.
+
+#### IK Constraints and bone limits
+
+IK constraints offer the easiest way to verify how a kinematic chain would
+behave in a physical environment. For that reason we chose to use them, along
+with IK bone limits, to model MuJoCo joints.
+
+
+
+* Select the bone at the end of a kinematic chain.
+
+
+
+* Add an IK Bone Constraint to it.
+ [This webpage](https://docs.blender.org/manual/en/latest/animation/constraints/tracking/ik_solver.html)
+ contains detailed instructions.
+
+
+
+* Add a target bone to the scene and use it as the IK constraint's target.
+
+
+
+* Adjust the `Chain Length` value to reflect the number of parent bones that
+ should become a part of this IK chain.
+
+
+
+* For each of the bones in the chain, visit the `Bone/Inverse Kinematics` tab
+ and adjust the locks and bone limits.
+
+
+
+*NOTE*: you can take advantage of a helpful limit visualization gizmo.
+
+### Armature free joints
+
+Armatures form the kinematic trees, so it makes sense to give their roots all 6
+degrees of freedom.
+
+This behavior can be changed by disabling the `Armature freejoint` export
+option.
+
+### Modelling the geometry
+
+MuJoCo mainly uses parametric geometry composed of primitive shapes such as
+cubes, spheres and capsules, although it also supports triangle-based mesh
+geometry.
+
+Blender, on the other hand, deals exclusively in meshes. Although it offers a
+palette of such primitive shapes, these are not parametric.
+
+#### MSH files
+
+The exporter therefore exports the geometry as meshes. All meshes referenced by
+the scene are exported into MuJoCo's native `.msh` format.
+
+When this plugin was created, the native `.msh` format was the only format to
+support texture mapping. Since MuJoCo 2.1.2, `.obj` files are supported. The
+`.msh` format is expected to be deprecated and removed soon. Until this plugin
+is updated to output `.obj` files, we recommend that all users convert their
+`.msh` files to `.obj` files using
+[this utility](https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/msh2obj.py).
+
+#### Phong lighting model and lack of support for Cycles nodes
+
+MuJoCo implements a forward rendering model with Phong lighting and support for
+reflective surfaces.
+
+The [MuJoCo material definition](http://www.mujoco.org/book/XMLreference.html#material)
+is fixed and limited to:
+
+* base color (diffuse color)
+* specularity coefficient
+* smoothness coefficient
+* reflectiveness coefficient
+
+These parameters translate well to Blender's default material definition, which
+should be used to model the materials exported to MuJoCo.
+
+Due to the diversity and open-ended nature of Cycles' material nodes, the choice
+was made not to support materials defined using them.
+
+#### Division of meshes that employ multiple materials
+
+The MuJoCo renderer supports a single material per mesh, in contrast to
+Blender's multi-material model. In order to support meshes with multiple
+materials, the exporter divides them into submeshes, each of which uses a single
+material.
+
+Since Blender materials are applied to faces, and a single face may only have a
+single material assigned, there is no risk that the subdivided mesh will exhibit
+any overlapping artifacts.
+
+#### Effect of mesh division on inertia and mass
+
+The exported submeshes will have different geometry and volume than the
+original. This will affect the mass and inertia of the geom that uses the
+exported mesh.
+
+Please note that MuJoCo derives these quantities not from the mesh itself, but
+rather from the convex hull calculated for it. The sum of the hulls of the
+subdivided meshes is not guaranteed to be equal to the hull of the original
+mesh.
+
+This phenomenon is illustrated in the images below.
+
+Blender | MuJoCo renderer | MuJoCo convex hulls
+----------------------------- | ------------------------------- | -------------------
+ |  | 
+
+A different material was assigned to 2 out of 6 rectangular faces of the right
+cube, causing that mesh to be split into two submeshes.
+
+While they render well in MuJoCo, the Moiré effect visible on the convex hull
+comes from two overlapping hulls, each with a different volume.
+
+#### Double-sided materials
+
+Double-sided materials cause the meshes (or faces) that use them to be exported
+with the faces duplicated in the reverse winding order. This is because the
+MuJoCo renderer supports no face culling mode other than back-face culling.
+
+*CAUTION*: Because this operation affects the exported geometry, it may
+indirectly affect the physical properties of a geom that references that mesh,
+such as its mass and inertia. Please use this feature with caution.
+
+#### Scaling
+
+*CAUTION*: If you are using the scaling transform, the exporter will modify your
+scene!
+
+The exporter will by default reset the scaling transform on all bones and meshes
+to ensure affine reference frame transformations.
+
+This operation can be undone after the export completes, but the exporter
+doesn't undo it automatically!
+
+### Textures
+
+*NOTE*: This feature will be added in the next version of the exporter.
+
+#### One texture channel
+
+The MuJoCo renderer supports a single texture channel. It's therefore advised
+to use a single UV map for meshes that are to be exported to MuJoCo.
+
+#### Choice between atlases and individual textures
+
+MuJoCo's fixed rendering pipeline is quite fast, and handles both variants well.
+Therefore we leave it up to the artist to decide which they prefer.
+
+Please keep in mind that the exporter doesn't export the texture assets. Those
+should be copied into the folder with the exported model manually.
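The README sections "Division of meshes that employ multiple materials" and "Double-sided materials" above describe two geometry transformations. The sketch below illustrates both on plain Python data, assuming triangles are given as tuples of vertex indices with one material index per face. The function names `split_by_material` and `add_back_faces` are hypothetical and chosen for this example. The exporter's own implementation is not part of this section and may differ.

```python
from collections import defaultdict


def split_by_material(faces, material_indices):
  """Groups faces into one submesh (list of faces) per material index."""
  submeshes = defaultdict(list)
  for face, material in zip(faces, material_indices):
    submeshes[material].append(face)
  return dict(submeshes)


def add_back_faces(faces):
  """Returns the faces plus a copy of each with reversed winding order."""
  return list(faces) + [tuple(reversed(face)) for face in faces]


# Three triangles: the first two use material 0, the last uses material 1.
faces = [(0, 1, 2), (0, 2, 3), (4, 5, 6)]
material_indices = [0, 0, 1]

submeshes = split_by_material(faces, material_indices)
print(submeshes)  # {0: [(0, 1, 2), (0, 2, 3)], 1: [(4, 5, 6)]}

double_sided = add_back_faces(submeshes[1])
print(double_sided)  # [(4, 5, 6), (6, 5, 4)]
```

Because each face carries exactly one material, every face lands in exactly one submesh, which is why the README can state that the split produces no overlapping faces. The duplicated back faces double the face count of a submesh, which is one way the exported geometry can end up differing from the original.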
diff --git a/dm_control/blender/mujoco_exporter/__init__.py b/dm_control/blender/mujoco_exporter/__init__.py
new file mode 100644
index 00000000..a33c6292
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/__init__.py
@@ -0,0 +1,170 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Blender 3.4 plugin for exporting models to MuJoCo native format."""
+
+import contextlib
+import os
+from xml.dom import minidom
+
+import bpy
+from bpy_extras.io_utils import ExportHelper
+
+from . import blender_scene
+from . import mujoco_assets
+from . import mujoco_scene
+
+
+bl_info = {
+ 'name': 'Export MuJoCo',
+ 'author': 'The dm_control authors',
+ 'version': (2, 0),
+ 'blender': (3, 3, 1),
+ 'location': 'File > Export > MuJoCo',
+ 'warning': '',
+ 'description': 'Export articulated MuJoCo model',
+ 'doc_url': '',
+ 'category': 'Import-Export',
+}
+
+
+@contextlib.contextmanager
+def context_settings_cacher(context: bpy.types.Context):
+ """Preserves the pose of exported objects and the scene mode."""
+ # Cache the mode
+ prev_mode = context.mode
+
+ # Set the Object mode required by the exporter
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ # Set the armatures in their neutral pose
+ pose_positions = []
+ for o in context.scene.objects:
+ if o.type == 'ARMATURE':
+ pose_positions.append((o, o.data.pose_position))
+ o.data.pose_position = 'REST'
+ context.view_layer.update()
+
+ try:
+ yield
+ finally:
+ # Restore the poses.
+ for armature, pose_position in pose_positions:
+ armature.data.pose_position = pose_position
+ context.view_layer.update()
+
+ # Restore the mode
+ bpy.ops.object.mode_set(mode=prev_mode)
+
+
+def apply_scale():
+ bpy.ops.object.select_all(action='SELECT')
+ bpy.ops.object.transform_apply(location=False, scale=True, rotation=False)
+ bpy.ops.object.select_all(action='DESELECT')
+
+
+class ExportMjcf(bpy.types.Operator, ExportHelper):
+ """Export to MJCF file format."""
+
+ bl_idname = 'export_scene.mjcf'
+ bl_label = 'Export MJCF'
+
+ filename_ext = '.xml'
+ filter_glob = bpy.props.StringProperty(default='*.xml', options={'HIDDEN'})
+
+ # Export settings
+ armature_freejoint: bpy.props.BoolProperty(
+ name='Armature freejoint',
+ description='Add a freejoint to armature body',
+ default=False,
+ )
+ apply_mesh_modifiers: bpy.props.BoolProperty(
+ name='Apply modifiers',
+ description='Apply mesh modifiers',
+ default=False,
+ )
+
+ def _export_mjcf(self, context: bpy.types.Context) -> None:
+ """Converts a Blender scene to Mujoco XML format."""
+ # Create a new XML document
+ xml_doc = minidom.Document()
+ mujoco = xml_doc.createElement('mujoco')
+ xml_doc.appendChild(mujoco)
+
+ # Create a list of blender objects, arranged according to their hierarchy,
+ # where parents precede the children.
+ blender_objects = blender_scene.map_blender_tree(
+ context, lambda o: o if o.is_visible else None
+ )
+ # Remove None entries that correspond to invisible objects.
+ blender_objects = [o for o in blender_objects if o]
+
+ export_settings = self.as_keywords()
+
+ # Build the scene tree
+ worldbody_el = mujoco_scene.export_to_xml(
+ doc=xml_doc,
+ objects=blender_objects,
+ armature_freejoint=export_settings['armature_freejoint'],
+ )
+ mujoco.appendChild(worldbody_el)
+
+ # Build the asset tree
+ asset_el = mujoco_assets.export_to_xml(
+ doc=xml_doc,
+ objects=blender_objects,
+ folder=os.path.dirname(export_settings['filepath']),
+ apply_mesh_modifiers=export_settings['apply_mesh_modifiers'],
+ )
+ mujoco.appendChild(asset_el)
+
+ # Add compiler options that allow exporting meshes with small features.
+ compiler_el = xml_doc.createElement('compiler')
+ mujoco.appendChild(compiler_el)
+ compiler_el.setAttribute('boundmass', '1e-3')
+ compiler_el.setAttribute('boundinertia', '1e-9')
+ # TODO(b/266818670): support assets export into subdirectory
+ # compiler_el.setAttribute('meshdir', 'assets')
+
+ # Write the XML to a file.
+ with open(export_settings['filepath'], 'w') as file:
+ file.write(xml_doc.toprettyxml(indent=' '))
+
+ def execute(self, context: bpy.types.Context):
+ """Export the scene."""
+ with context_settings_cacher(context):
+ apply_scale()
+ self._export_mjcf(context)
+
+ return {'FINISHED'}
+
+
+def menu_func_export(self, context: bpy.types.Context):
+ del context
+ self.layout.operator(ExportMjcf.bl_idname, text='MuJoCo (.xml)')
+
+
+def register():
+ bpy.utils.register_class(ExportMjcf)
+ bpy.types.TOPBAR_MT_file_export.append(menu_func_export)
+
+
+def unregister():
+ bpy.utils.unregister_class(ExportMjcf)
+ bpy.types.TOPBAR_MT_file_export.remove(menu_func_export)
+
+
+if __name__ == '__main__':
+ register()
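The document structure that `_export_mjcf` assembles can be reproduced outside Blender with the standard library alone. This is a minimal sketch under stated assumptions: the `worldbody` and `asset` element names are guesses based on the variable names `worldbody_el` and `asset_el` (the real elements come from `mujoco_scene.export_to_xml` and `mujoco_assets.export_to_xml`, which are not shown here), and `build_mjcf_skeleton` is a hypothetical helper name. The `compiler` attributes are copied from the hunk above.

```python
from xml.dom import minidom


def build_mjcf_skeleton() -> str:
  """Builds an empty MJCF document with the exporter's compiler settings."""
  xml_doc = minidom.Document()
  mujoco = xml_doc.createElement('mujoco')
  xml_doc.appendChild(mujoco)

  # Placeholders for the scene and asset trees (element names assumed).
  mujoco.appendChild(xml_doc.createElement('worldbody'))
  mujoco.appendChild(xml_doc.createElement('asset'))

  # Compiler options copied from _export_mjcf.
  compiler_el = xml_doc.createElement('compiler')
  mujoco.appendChild(compiler_el)
  compiler_el.setAttribute('boundmass', '1e-3')
  compiler_el.setAttribute('boundinertia', '1e-9')

  return xml_doc.toprettyxml(indent='  ')


skeleton = build_mjcf_skeleton()
print(skeleton)
```

Running the sketch outside Blender is a quick way to inspect the shape of the generated XML without installing the add-on.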
diff --git a/dm_control/blender/mujoco_exporter/blender_scene.py b/dm_control/blender/mujoco_exporter/blender_scene.py
new file mode 100644
index 00000000..3004fdfe
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/blender_scene.py
@@ -0,0 +1,428 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Blender scene parsers."""
+
+# disable: pytype=strict-none
+import dataclasses
+import math
+from typing import Any, Callable, Sequence, Tuple, cast
+
+from dm_control.blender.fake_core import bpy
+from dm_control.blender.fake_core import mathutils
+
+_ARMATURE = 'ARMATURE'
+_CAMERA = 'CAMERA'
+_EMPTY = 'EMPTY'
+_LIGHT = 'LIGHT'
+_MESH = 'MESH'
+_VEC_ZERO = mathutils.Vector((0, 0, 0))
+_OX = mathutils.Vector((1, 0, 0))
+_OY = mathutils.Vector((0, 1, 0))
+_OZ = mathutils.Vector((0, 0, 1))
+
+
+def _check_that_parent_bone_exists(obj: bpy.types.Object) -> None:
+ """Checks if the bone the object was supposed to be parented exists."""
+ assert obj is not None
+ if obj.parent is None:
+ raise ValueError(
+ 'Armature that object "{}" is parented to does not exist'.format(obj.name)
+ )
+
+ armature = obj.parent
+ if armature.type != _ARMATURE:
+ raise ValueError(
+ 'Parent of object "{}" - "{}", is not an armature, but a "{}"'.format(
+ obj.name, armature.name, armature.type
+ )
+ )
+
+ if obj.parent_bone not in armature.data.bones:
+ raise ValueError(
+ 'Object "{}" is parented to a non-existing bone "{}" from '
+ 'armature "{}"'.format(obj.name, obj.parent_bone, armature.name)
+ )
+
+
+def _check_constraint_in_local_space(
+ constraint: bpy.types.Constraint, owner: 'ObjectRef'
+) -> None:
+ if constraint and constraint.owner_space != 'LOCAL':
+ raise ValueError(
+ 'Constraint "{}" (bone "{}", armature "{}") uses an unsupported '
+ 'owner_space "{}". Only "LOCAL" space is supported at the '
+ 'moment'.format(
+ type(constraint),
+ owner.bone_name(),
+ owner.obj_name(),
+ constraint.owner_space,
+ )
+ )
+
+
+def _angle_distance(lhs_deg: float, rhs_deg: float) -> float:
+ """Calculates the shortest distance between two angles, in degrees."""
+ x_2 = math.cos(math.radians(lhs_deg)) * math.cos(math.radians(rhs_deg))
+ y_2 = math.sin(math.radians(lhs_deg)) * math.sin(math.radians(rhs_deg))
+ return math.degrees(math.acos(max(-1.0, min(1.0, x_2 + y_2))))
+
+
+@dataclasses.dataclass(frozen=True)
+class AffineTransform:
+ pos: mathutils.Vector
+ rot: mathutils.Quaternion
+
+
+@dataclasses.dataclass(frozen=True)
+class Dof:
+ """Degree of freedom description."""
+
+ name: str
+ axis: mathutils.Vector
+ limited: bool = False
+ limits: Tuple[float, float] = (0, 0)
+
+
+@dataclasses.dataclass(frozen=True)
+class ObjectRef:
+ """References a Blender object, be that a scene object or a bone.
+
+ Object reference is hashable and comparable. Equality is based on the names
+ of the underlying objects. Blender's API guarantees that a combination of
+ object/bone name will be unique across the scene. We're leveraging that rule.
+ """
+
+ native_obj: bpy.types.Object | None
+ native_bone: bpy.types.Bone | None = None
+
+ def __hash__(self) -> int:
+ return hash(self.name)
+
+ def __eq__(self, rhs) -> bool:
+ if not isinstance(rhs, ObjectRef):
+ return False
+ return self.name == rhs.name
+
+ @property
+ def is_none(self) -> bool:
+ return self.native_obj is None
+
+ @property
+ def is_armature(self) -> bool:
+ return bool(
+ self.native_obj
+ and self.native_obj.type == _ARMATURE
+ and not self.native_bone
+ )
+
+ @property
+ def is_bone(self) -> bool:
+ return bool(
+ self.native_obj
+ and self.native_obj.type == _ARMATURE
+ and self.native_bone
+ )
+
+ @property
+ def is_mesh(self) -> bool:
+ return bool(self.native_obj and self.native_obj.type == _MESH)
+
+ @property
+ def is_light(self) -> bool:
+ return bool(self.native_obj and self.native_obj.type == _LIGHT)
+
+ @property
+ def is_camera(self) -> bool:
+ return bool(self.native_obj and self.native_obj.type == _CAMERA)
+
+ @property
+ def is_empty(self) -> bool:
+ return bool(self.native_obj and self.native_obj.type == _EMPTY)
+
+ @property
+ def name(self) -> str:
+ if not self.native_obj:
+ return ''
+ if self.native_bone:
+ return '{}_{}'.format(self.native_bone.name, self.native_obj.name)
+ else:
+ return self.native_obj.name
+
+ def as_light(self) -> bpy.types.Light:
+ return cast(bpy.types.Light, self.obj_data())
+
+ def get_local_transform(self) -> AffineTransform:
+ """Returns a transform wrt. the local reference frame."""
+
+ def get_bone_local_mtx(bone: bpy.types.Bone) -> mathutils.Matrix:
+ """Derives a local matrix of an armature bone."""
+ assert isinstance(bone, bpy.types.Bone)
+ if bone.parent:
+ return bone.parent.matrix_local.inverted() @ bone.matrix_local
+ else:
+ return bone.matrix_local
+
+ def get_object_local_mtx(obj: bpy.types.Object) -> mathutils.Matrix:
+ """Derives a local matrix of an object, such as a mesh or armature."""
+ assert isinstance(obj, bpy.types.Object)
+ if obj.parent:
+ local_mtx = obj.parent.matrix_world.inverted() @ obj.matrix_world
+ if obj.parent_bone:
+ assert obj.parent.type == _ARMATURE
+ armature = obj.parent
+ bone = armature.data.bones[obj.parent_bone]
+ return bone.matrix_local.inverted() @ local_mtx
+ else:
+ return local_mtx
+ else:
+ return obj.matrix_world
+
+ if self.is_bone:
+ local_mtx = get_bone_local_mtx(self.native_bone)
+ else:
+ local_mtx = get_object_local_mtx(self.native_obj)
+
+ rot_quat = local_mtx.to_quaternion()
+ pos = local_mtx.translation
+ return AffineTransform(pos, rot_quat)
+
+ @property
+ def is_visible(self) -> bool:
+ """Checks if an object is visible."""
+ if not self.native_obj:
+ return False
+ if hasattr(self.native_obj, 'visible_get'):
+ return self.native_obj.visible_get()
+ else:
+ return True
+
+ @property
+ def parent(self) -> 'ObjectRef':
+ """Returns a reference to the parent of this object."""
+ if self.native_obj is None:
+ return ObjectRef(None)
+
+ def bone_parent(
+ armature: bpy.types.Object, bone: bpy.types.Bone
+ ) -> 'ObjectRef':
+ if bone.parent:
+ # Parent of a child bone is its parent bone in the same armature
+ assert isinstance(bone.parent, bpy.types.Bone)
+ return ObjectRef.new_bone(armature, bone.parent)
+ else:
+ # Parent of the root bone is the armature.
+ return ObjectRef.new_object(armature)
+
+ def object_parent(obj: bpy.types.Object) -> 'ObjectRef':
+ if obj.parent_bone:
+ # The object is parented to an armature, and the .parent field must
+ # contain a reference to the armature.
+ assert obj.parent and obj.parent.type == _ARMATURE
+ _check_that_parent_bone_exists(obj)
+ armature = obj.parent
+ parent_bone = armature.data.bones[obj.parent_bone]
+ return ObjectRef.new_bone(armature, parent_bone)
+ elif obj.parent:
+ # The object is parented to another object.
+ return ObjectRef.new_object(obj.parent)
+ else:
+ # This is a root object
+ return ObjectRef(None)
+
+ if self.native_bone:
+ return bone_parent(self.native_obj, self.native_bone)
+ else:
+ return object_parent(self.native_obj)
+
+ def get_rotation_dofs(self) -> Sequence[Dof]:
+ """Returns the rotation degrees of freedom in the subtractive mode.
+
+ The method returns the degrees of freedom present from the point of view
+ of Blender. These are modeled such that in absence of constraints, all
+ degrees of freedom are present.
+
+ Returns:
+ A sequence of degree of freedom definitions.
+ """
+ if not self.is_bone:
+ raise ValueError(
+ 'Rotation degrees of freedom are defined only for bones.'
+ )
+ assert self.native_obj
+ assert self.native_obj.type == _ARMATURE
+ armature = self.native_obj
+ bone = armature.pose.bones[self.native_bone.name]
+ if not bone.is_in_ik_chain:
+ # Bones not in an IK chain don't receive any degrees of freedom.
+ return []
+
+ is_locked = [bone.lock_ik_x, bone.lock_ik_y, bone.lock_ik_z]
+ use_limits = [bone.use_ik_limit_x, bone.use_ik_limit_y, bone.use_ik_limit_z]
+ limits = [
+ (math.degrees(bone.ik_min_x), math.degrees(bone.ik_max_x)),
+ (math.degrees(bone.ik_min_y), math.degrees(bone.ik_max_y)),
+ (math.degrees(bone.ik_min_z), math.degrees(bone.ik_max_z)),
+ ]
+ axes = [_OX, _OY, _OZ]
+ names = [
+ 'rx_{}'.format(self.name),
+ 'ry_{}'.format(self.name),
+ 'rz_{}'.format(self.name),
+ ]
+ axis_names = ['X', 'Y', 'Z']
+
+ def build_dof(idx):
+ if limits[idx][0] >= limits[idx][1]:
+ raise ValueError(
+ 'Bone "{}" uses incorrect IK limits for {} axis. '
+ '{} < {} is violated'.format(
+ self.name, axis_names[idx], limits[idx][0], limits[idx][1]
+ )
+ )
+ return Dof(
+ name=names[idx],
+ axis=axes[idx],
+ limited=use_limits[idx],
+ limits=limits[idx],
+ )
+
+ return [build_dof(i) for i in range(3) if not is_locked[i]]
+
+ def obj_data(self) -> bpy.types.Object | None:
+ if self.native_obj is not None:
+ return self.native_obj.data
+ else:
+ return None
+
+ def obj_name(self) -> str | None:
+ if self.native_obj is not None:
+ return self.native_obj.name
+ else:
+ return None
+
+ def bone_name(self) -> str | None:
+ if self.native_bone is not None:
+ return self.native_bone.name
+ else:
+ return None
+
+ @property
+ def mesh(self) -> bpy.types.Mesh:
+ """Returns the mesh associated with this object."""
+ assert self.is_mesh
+ return cast(bpy.types.Mesh, self.obj_data())
+
+ def get_modified_mesh(self) -> bpy.types.Mesh | None:
+ """Returns a mesh with modifiers applied to it."""
+ assert self.is_mesh
+ assert self.native_obj is not None
+ if self.native_obj.mode == 'EDIT':
+ self.native_obj.update_from_editmode()
+
+ # Evaluate the depsgraph to obtain the object with modifiers applied.
+ depsgraph = bpy.context.evaluated_depsgraph_get()
+ mesh_owner = self.native_obj.evaluated_get(depsgraph)
+
+ return mesh_owner.to_mesh()
+
+ @property
+ def materials(self) -> Sequence[bpy.types.Material]:
+ """Returns the materials assigned to this object."""
+ data = self.obj_data()
+ if hasattr(data, 'materials'):
+ return data.materials
+ else:
+ return []
+
+ @classmethod
+ def new_object(cls, obj: bpy.types.Object) -> 'ObjectRef':
+ assert isinstance(obj, bpy.types.Object) or obj is None
+ return cls(obj)
+
+ @classmethod
+ def new_bone(
+ cls,
+ armature: bpy.types.Object,
+ bone: bpy.types.Bone | None,
+ ) -> 'ObjectRef':
+ assert isinstance(armature, bpy.types.Object)
+ assert armature.type == _ARMATURE
+ assert isinstance(bone, bpy.types.Bone) or bone is None
+ return cls(armature, bone)
+
+
+NoneRef = ObjectRef(None, None)
+
+
+def map_blender_tree(
+ context: bpy.types.Context, callback: Callable[[ObjectRef], Any]
+) -> Sequence[Any]:
+ """Maps a callback over the scene tree, visiting parents before children."""
+ # Collect all nodes to explore - objects and bones alike
+ assert context.scene is not None
+ to_explore = [ObjectRef.new_object(o) for o in context.scene.objects]
+ armatures = [o for o in context.scene.objects if o.type == _ARMATURE]
+ for armature in armatures:
+ for bone in armature.data.bones:
+ to_explore.append(ObjectRef.new_bone(armature, bone))
+
+ explored = set()
+ explored.add(NoneRef)
+ result = []
+
+ while to_explore:
+ obj_ref: ObjectRef = to_explore[0]
+ to_explore = to_explore[1:]
+
+ if obj_ref.parent in explored:
+ explored.add(obj_ref)
+ result.append(callback(obj_ref))
+ else:
+ to_explore.append(obj_ref)
+
+ return result
+
+
+def get_material_mesh_pair_name(mesh_name: str, mat_name: str) -> str:
+ """Builds the name for a mesh-material pair."""
+ return '{}_{}'.format(mesh_name, mat_name) if mat_name else mesh_name
+
+
+def is_material_mesh_pair_valid(mesh: bpy.types.Mesh, mat_idx: int) -> bool:
+ """Tests whether the mesh-material pair contains any geometry."""
+ mesh.calc_loop_triangles()
+ faces = [f for f in mesh.loop_triangles if f.material_index == mat_idx]
+ return bool(faces)
+
+
+def map_materials(
+ func: Callable[[bpy.types.Material], Any],
+ materials: Sequence[bpy.types.Material],
+) -> Sequence[Any]:
+ """Maps a collection of materials, adjusting for empty collections.
+
+ A missing or empty collection maps to an empty list; substituting a default
+ material in that case is left to the caller.
+
+ Args:
+ func: Mapping callback.
+ materials: Collection of materials.
+
+ Returns:
+ An arbitrary collection of data mapped out of the materials collection.
+ """
+ materials = materials or []
+ return [func(material) for material in materials]
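
Review note on `map_blender_tree`: the loop re-queues any node whose parent has not been visited yet, so the callback sees every parent before its children. The order is parent-first rather than strictly breadth-first. Below is a minimal standalone sketch of that ordering, using plain strings instead of `ObjectRef`. The names `parents_first` and `parent_of` are hypothetical and not part of this change, and a `deque` is swapped in for the list slicing used above.

```python
from collections import deque
from typing import Dict, List, Optional


def parents_first(
    nodes: List[str], parent_of: Dict[str, Optional[str]]
) -> List[str]:
  """Orders nodes so that every parent precedes its children."""
  queue = deque(nodes)
  explored = {None}  # None plays the role of NoneRef: the parent of roots.
  result = []
  while queue:
    node = queue.popleft()
    if parent_of[node] in explored:
      explored.add(node)
      result.append(node)
    else:
      # The parent has not been visited yet, so retry this node later.
      queue.append(node)
  return result


# Bones deliberately listed out of hierarchy order.
bones = ['r_finger', 'l_arm', 'root', 'r_arm']
parent_of = {
    'root': None,
    'r_arm': 'root',
    'l_arm': 'root',
    'r_finger': 'r_arm',
}
order = parents_first(bones, parent_of)
print(order)
```

Like the loop it mirrors, this sketch does not terminate if a node's parent is missing from `nodes`, so callers would need to guarantee a closed hierarchy.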
diff --git a/dm_control/blender/mujoco_exporter/blender_scene_test.py b/dm_control/blender/mujoco_exporter/blender_scene_test.py
new file mode 100644
index 00000000..81ba49f6
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/blender_scene_test.py
@@ -0,0 +1,244 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for blender_scene.py."""
+
+import math
+from unittest import mock
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.blender.fake_core import bpy
+from dm_control.blender.fake_core import mathutils
+from dm_control.blender.mujoco_exporter import blender_scene
+from dm_control.blender.mujoco_exporter import testing
+import numpy as np
+
+
+class AngleDistanceTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ [0, 0, 0],
+ [90, 90, 0],
+ [-90, -90, 0],
+ [-90, 0, 90],
+ [0, 90, 90],
+ [-90, 90, 180],
+ [-180, 180, 0],
+ [-270, 90, 0],
+ [-90, 270, 0],
+ [-90, 360, 90],
+ ])
+ def test_angle_distance(self, lhs, rhs, result):
+ self.assertAlmostEqual(blender_scene._angle_distance(lhs, rhs), result, 2)
+
+
+class ObjectRefTest(absltest.TestCase):
+
+ def test_hashing_and_comparison(self):
+ mesh_a = blender_scene.ObjectRef.new_object(
+ testing.build_mesh_object('mesh_a')
+ )
+ mesh_a_copy = blender_scene.ObjectRef.new_object(
+ testing.build_mesh_object('mesh_a')
+ )
+ mesh_b = blender_scene.ObjectRef.new_object(
+ testing.build_mesh_object('mesh_b')
+ )
+
+ self.assertEqual(hash(mesh_a), hash(mesh_a_copy))
+ self.assertNotEqual(hash(mesh_a), hash(mesh_b))
+
+ self.assertEqual(mesh_a, mesh_a_copy)
+ self.assertNotEqual(mesh_a, mesh_b)
+
+ def test_parent_of_root_object_is_noneref(self):
+ obj_ref = blender_scene.ObjectRef.new_object(
+ testing.build_mesh_object('mesh')
+ )
+ self.assertEqual(obj_ref.parent, blender_scene.NoneRef)
+
+ def test_parent_of_object_parented_to_another_object(self):
+ parent_obj = testing.build_mesh_object('parent_obj')
+ child_obj = testing.build_mesh_object('child_obj', parent=parent_obj)
+
+ parent_ref = blender_scene.ObjectRef.new_object(parent_obj)
+ child_ref = blender_scene.ObjectRef.new_object(child_obj)
+ self.assertEqual(child_ref.parent, parent_ref)
+
+ def test_parent_of_object_parented_to_bone(self):
+ bone = testing.build_bone('bone')
+ armature = testing.build_armature('armature', bones=[bone])
+ child_obj = testing.build_mesh_object(
+ 'child_obj', parent=armature, parent_bone=bone.name
+ )
+
+ parent_ref = blender_scene.ObjectRef.new_bone(armature, bone)
+ child_ref = blender_scene.ObjectRef.new_object(child_obj)
+ self.assertEqual(child_ref.parent, parent_ref)
+
+ def test_parent_of_root_bone_is_armature(self):
+ bone = testing.build_bone('bone')
+ armature = testing.build_armature('armature', bones=[bone])
+
+ armature_ref = blender_scene.ObjectRef.new_object(armature)
+ bone_ref = blender_scene.ObjectRef.new_bone(armature, bone)
+ self.assertEqual(bone_ref.parent, armature_ref)
+
+
+class BlenderTreeMapperTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ [['obj1']],
+ [['obj1', 'obj2']],
+ [['obj2', 'obj1']],
+ [['obj2', 'obj1', 'obj3']],
+ ])
+ def test_object_without_hierarchy(self, object_names):
+ context = mock.MagicMock(spec=bpy.types.Context)
+ context.scene.objects = [
+ testing.build_mesh_object(name) for name in object_names
+ ]
+
+ objects = blender_scene.map_blender_tree(context, lambda x: x)
+ self.assertEqual([o.native_obj.name for o in objects], object_names)
+
+ @parameterized.parameters([
+ # A hierarchy with 2 objects
+ dict(
+ parent_child_tuples=[('obj_1', 'obj_1_1')],
+ expected_order=['obj_1', 'obj_1_1'],
+ ),
+ # Shallow hierarchy with 3 objects
+ dict(
+ parent_child_tuples=[('obj_1', 'obj_1_1'), ('obj_1', 'obj_1_2')],
+ expected_order=['obj_1', 'obj_1_1', 'obj_1_2'],
+ ),
+ # Deep hierarchy with 3 objects
+ dict(
+ parent_child_tuples=[('obj_1', 'obj_1_1'), ('obj_1_1', 'obj_1_2')],
+ expected_order=['obj_1', 'obj_1_1', 'obj_1_2'],
+ ),
+ # 2 hierarchies, 2 objects each
+ dict(
+ parent_child_tuples=[('obj_1', 'obj_1_1'), ('obj_2', 'obj_2_1')],
+ expected_order=['obj_1', 'obj_1_1', 'obj_2', 'obj_2_1'],
+ ),
+ ])
+ def test_object_with_hierarchy(self, parent_child_tuples, expected_order):
+ objects = {}
+ for parent_name, child_name in parent_child_tuples:
+ parent = objects.get(parent_name, None)
+ if not parent:
+ parent = testing.build_mesh_object(parent_name)
+ objects[parent_name] = parent
+
+ child = testing.build_mesh_object(child_name, parent=parent)
+ objects[child_name] = child
+
+ context = mock.MagicMock(spec=bpy.types.Context)
+ context.scene.objects = [o for o in objects.values()]
+
+ object_names = blender_scene.map_blender_tree(context, lambda x: x.name)
+ self.assertEqual(object_names, expected_order)
+
+ def test_boneless_armature(self):
+ armature = testing.build_armature('armature')
+
+ context = mock.MagicMock(spec=bpy.types.Context)
+ context.scene.objects = [armature]
+
+ objects = blender_scene.map_blender_tree(context, lambda x: x)
+ self.assertLen(objects, 1)
+
+ def test_armature_with_bones(self):
+ root = testing.build_bone('root')
+ r_arm = testing.build_bone('r_arm', parent=root)
+ l_arm = testing.build_bone('l_arm', parent=root)
+ r_finger = testing.build_bone('r_finger', parent=r_arm)
+ armature = testing.build_armature(
+ 'armature', bones=[r_finger, l_arm, root, r_arm]
+ )
+ expected_order = [
+ 'armature',
+ 'root_armature',
+ 'r_arm_armature',
+ 'r_finger_armature',
+ 'l_arm_armature',
+ ]
+
+ context = mock.MagicMock(spec=bpy.types.Context)
+ # The bones above are listed out of hierarchy order to test the sorting.
+ context.scene.objects = [armature]
+
+ object_names = blender_scene.map_blender_tree(context, lambda x: x.name)
+ self.assertEqual(object_names, expected_order)
+
+
+class DegreesOfFreedomTest(parameterized.TestCase):
+
+ def test_ik_with_unlimited_dofs_adds_all_dofs(self):
+ constraint = testing.build_rotation_constraint(
+ use_limit_x=False, use_limit_y=False, use_limit_z=False
+ )
+ bone = testing.build_bone('bone', constraint=constraint)
+ armature = testing.build_armature('bone', bones=[bone])
+ bone_ref = blender_scene.ObjectRef.new_bone(armature, bone)
+
+ dofs = bone_ref.get_rotation_dofs()
+ self.assertLen(dofs, 3)
+ self.assertEqual(dofs[0].axis, mathutils.Vector((1, 0, 0)))
+ self.assertEqual(dofs[1].axis, mathutils.Vector((0, 1, 0)))
+ self.assertEqual(dofs[2].axis, mathutils.Vector((0, 0, 1)))
+ self.assertFalse(dofs[0].limited)
+ self.assertFalse(dofs[1].limited)
+ self.assertFalse(dofs[2].limited)
+
+ def test_ik_with_removed_dofs(self):
+ constraint = testing.build_rotation_constraint(
+ lock_x=False,
+ lock_y=True,
+ lock_z=False,
+ use_limit_x=True,
+ use_limit_z=False,
+ )
+ bone = testing.build_bone('bone', constraint=constraint)
+ armature = testing.build_armature('bone', bones=[bone])
+ bone_ref = blender_scene.ObjectRef.new_bone(armature, bone)
+
+ dofs = bone_ref.get_rotation_dofs()
+ self.assertLen(dofs, 2)
+ self.assertEqual(dofs[0].axis, mathutils.Vector((1, 0, 0)))
+ self.assertEqual(dofs[1].axis, mathutils.Vector((0, 0, 1)))
+ self.assertTrue(dofs[0].limited)
+ self.assertFalse(dofs[1].limited)
+
+ def test_limiting_dof(self):
+ constraint = testing.build_rotation_constraint(
+ use_limit_x=True, min_x=math.radians(-15), max_x=math.radians(30)
+ )
+ bone = testing.build_bone('bone', constraint=constraint)
+ armature = testing.build_armature('bone', bones=[bone])
+ bone_ref = blender_scene.ObjectRef.new_bone(armature, bone)
+
+ dofs = bone_ref.get_rotation_dofs()
+ self.assertLen(dofs, 3)
+ self.assertTrue(dofs[0].limited)
+ self.assertFalse(dofs[1].limited)
+ self.assertFalse(dofs[2].limited)
+ np.testing.assert_allclose(dofs[0].limits, (-15, 30), atol=1e-2)
+
+
+if __name__ == '__main__':
+ absltest.main()
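
Review note on `AngleDistanceTest`: the expected values follow from the identity cos(a)cos(b) + sin(a)sin(b) = cos(a - b), so the result is the wrapped absolute difference in the range [0, 180]. The sketch below is independent of `blender_scene` and cross-checks the trigonometric form against a modulo-based form. Both function names are hypothetical, and the clamp before `acos` is an assumption added here to guard against rounding.

```python
import math


def angle_distance_deg(lhs_deg: float, rhs_deg: float) -> float:
  """Returns the shortest angular distance between two angles, in degrees."""
  lhs, rhs = math.radians(lhs_deg), math.radians(rhs_deg)
  # cos(lhs) * cos(rhs) + sin(lhs) * sin(rhs) equals cos(lhs - rhs).
  cos_delta = math.cos(lhs) * math.cos(rhs) + math.sin(lhs) * math.sin(rhs)
  # Clamp to the domain of acos; rounding can push the sum slightly past 1.
  cos_delta = max(-1.0, min(1.0, cos_delta))
  return math.degrees(math.acos(cos_delta))


def wrapped_difference_deg(lhs_deg: float, rhs_deg: float) -> float:
  """Same distance, computed by wrapping the difference into [-180, 180)."""
  return abs((lhs_deg - rhs_deg + 180.0) % 360.0 - 180.0)


for lhs, rhs in [(0, 0), (-90, 90), (-270, 90), (-90, 360)]:
  trig = angle_distance_deg(lhs, rhs)
  wrapped = wrapped_difference_deg(lhs, rhs)
  print(lhs, rhs, round(trig, 2), round(wrapped, 2))
```

The two forms agree up to floating-point error. The modulo form avoids the loss of precision `acos` has near 0 and 180 degrees, which may matter if limits are ever compared with a tight tolerance.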
diff --git a/dm_control/blender/mujoco_exporter/doc/add_ik_constraint.png b/dm_control/blender/mujoco_exporter/doc/add_ik_constraint.png
new file mode 100644
index 00000000..9b86681a
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/add_ik_constraint.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/convhull_blend.png b/dm_control/blender/mujoco_exporter/doc/convhull_blend.png
new file mode 100644
index 00000000..c1a56626
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/convhull_blend.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/convhull_mj_hull.png b/dm_control/blender/mujoco_exporter/doc/convhull_mj_hull.png
new file mode 100644
index 00000000..fb10b657
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/convhull_mj_hull.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/convhull_mj_rend.png b/dm_control/blender/mujoco_exporter/doc/convhull_mj_rend.png
new file mode 100644
index 00000000..094a8aa9
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/convhull_mj_rend.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/enable_plugin.png b/dm_control/blender/mujoco_exporter/doc/enable_plugin.png
new file mode 100644
index 00000000..f82c349a
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/enable_plugin.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/end_effector_selection.png b/dm_control/blender/mujoco_exporter/doc/end_effector_selection.png
new file mode 100644
index 00000000..4db770a7
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/end_effector_selection.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/gizmo.png b/dm_control/blender/mujoco_exporter/doc/gizmo.png
new file mode 100644
index 00000000..fe9d8762
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/gizmo.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/ik_chain_length.png b/dm_control/blender/mujoco_exporter/doc/ik_chain_length.png
new file mode 100644
index 00000000..3629c71d
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/ik_chain_length.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/ik_target.png b/dm_control/blender/mujoco_exporter/doc/ik_target.png
new file mode 100644
index 00000000..c73d5c36
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/ik_target.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/install_plugin.png b/dm_control/blender/mujoco_exporter/doc/install_plugin.png
new file mode 100644
index 00000000..9a66ebc7
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/install_plugin.png differ
diff --git a/dm_control/blender/mujoco_exporter/doc/limits.png b/dm_control/blender/mujoco_exporter/doc/limits.png
new file mode 100644
index 00000000..f41f2449
Binary files /dev/null and b/dm_control/blender/mujoco_exporter/doc/limits.png differ
diff --git a/dm_control/blender/mujoco_exporter/install.sh b/dm_control/blender/mujoco_exporter/install.sh
new file mode 100755
index 00000000..0c71fea8
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/install.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+set -euo pipefail
+
+export output_dir="./addons/mujoco_model_exporter"
+rm -rf "${output_dir}"
+mkdir -p "${output_dir}"
+
+cp ./*.py "${output_dir}"
+find "${output_dir}" -name "*.py" -exec sed -i "s/from dm_control.blender.fake_core //g" "{}" +;
+find "${output_dir}" -name "*.py" -exec sed -i "s/from dm_control.blender.mujoco_exporter/from ./g" "{}" +;
+echo "Add-on exported to ${output_dir}."
+echo "Copy this to Blender ./scripts/addons - see https://docs.blender.org/manual/en/latest/advanced/blender_directory_layout.html"
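
Review note on `install.sh`: the two `sed` passes turn the repository-style imports into the form the add-on needs inside Blender. The `fake_core` prefix is dropped so the real `bpy` and `mathutils` are imported, and sibling modules become relative imports. A minimal Python sketch of the same rewrite follows, for readers on platforms where `sed -i` behaves differently. The function name `rewrite_imports` is hypothetical, and unlike the `sed` patterns the dots are escaped here.

```python
import re

# (pattern, replacement) pairs mirroring the two sed substitutions.
_REWRITES = [
    (re.compile(r'from dm_control\.blender\.fake_core '), ''),
    (re.compile(r'from dm_control\.blender\.mujoco_exporter'), 'from .'),
]


def rewrite_imports(source: str) -> str:
  """Rewrites repository imports into their in-Blender add-on form."""
  for pattern, replacement in _REWRITES:
    source = pattern.sub(replacement, source)
  return source


example = (
    'from dm_control.blender.fake_core import bpy\n'
    'from dm_control.blender.mujoco_exporter import blender_scene\n'
)
print(rewrite_imports(example))
```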
diff --git a/dm_control/blender/mujoco_exporter/mujoco_assets.py b/dm_control/blender/mujoco_exporter/mujoco_assets.py
new file mode 100644
index 00000000..86b45d20
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/mujoco_assets.py
@@ -0,0 +1,175 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco asset exporters.
+
+Note about materials: material nodes are not supported, so please use basic
+Blender materials.
+"""
+
+import os
+import struct
+from typing import Sequence
+from xml.dom import minidom
+
+from dm_control.blender.fake_core import bpy
+from dm_control.blender.mujoco_exporter import blender_scene
+
+_MESH = 'MESH'
+
+
+class MujocoMesh:
+ """Mujoco representation of a mesh.
+
+ Mujoco can't handle meshes that reference multiple materials. Therefore
+ an instance of this class represents a submesh that uses a specific material.
+ """
+
+ def __init__(self, mesh: bpy.types.Mesh, material_idx: int, two_sided: bool):
+ if mesh.uv_layers.active is None:
+ raise ValueError(f'Mesh {mesh.name} does not have an active UV layer.')
+
+ mesh.calc_loop_triangles()
+
+ indices = []
+ for face in mesh.loop_triangles:
+ if face.material_index == material_idx:
+ indices.extend(face.vertices)
+
+ uv_layer = mesh.uv_layers.active.data
+ # pylint: disable=g-complex-comprehension
+ self.vertices = [c for i in indices for c in mesh.vertices[i].co]
+ self.normals = [c for i in indices for c in mesh.vertices[i].normal]
+ self.uvs = [c for i in indices for c in uv_layer[i].uv]
+ # pylint: enable=g-complex-comprehension
+ self.faces = list(range(len(self.vertices) // 3))
+
+ if two_sided:
+ # For two-sided meshes, duplicate the geometry with flipped normals
+ # and reversed triangle winding order.
+ base_vertex = len(self.vertices) // 3
+ self.vertices += list(self.vertices)
+ self.normals += [n * -1.0 for n in self.normals]
+ self.uvs += list(self.uvs)
+ self.faces += [i + base_vertex for i in reversed(self.faces)]
+
+ def save(self, filepath: str) -> None:
+ """Save the data in the MSH format."""
+ nvertex = len(self.vertices)
+ nnormal = len(self.normals)
+ ntexcoord = len(self.uvs)
+ nface = len(self.faces)
+ assert nvertex % 3 == 0
+ assert nnormal % 3 == 0
+ assert ntexcoord % 2 == 0
+ assert nface % 3 == 0
+
+ fmt_msh = '4i{}f{}f{}f{}i'.format(nvertex, nnormal, ntexcoord, nface)
+ with open(filepath, 'wb') as f:
+ f.write(
+ struct.pack(
+ fmt_msh,
+ nvertex // 3,
+ nnormal // 3,
+ ntexcoord // 2,
+ nface // 3,
+ *self.vertices,
+ *self.normals,
+ *self.uvs,
+ *self.faces,
+ )
+ )
+
+
+def mesh_asset_builder(
+ doc: minidom.Document,
+ mesh: bpy.types.Mesh,
+ materials: Sequence[bpy.types.Material],
+ folder: str,
+) -> Sequence[minidom.Element]:
+ """Exports a mesh associated with the object and creates an 'asset' node."""
+ mat_names = blender_scene.map_materials(lambda m: m.name, materials)
+ twosidedness = blender_scene.map_materials(
+ lambda m: not m.use_backface_culling, materials
+ )
+ mat_names = mat_names or ['']
+ twosidedness = twosidedness or [False]
+
+ elements = []
+ for mat_idx, (mat_name, twosided) in enumerate(zip(mat_names, twosidedness)):
+ el_name = blender_scene.get_material_mesh_pair_name(mesh.name, mat_name)
+ filename = '{}.msh'.format(el_name)
+ filepath = os.path.join(folder, filename)
+
+ if blender_scene.is_material_mesh_pair_valid(mesh, mat_idx):
+ mesh_to_export = MujocoMesh(mesh, mat_idx, twosided)
+ mesh_to_export.save(filepath)
+
+ el = doc.createElement('mesh')
+ el.setAttribute('name', el_name)
+ el.setAttribute('file', filename)
+ elements.append(el)
+ return elements
+
+
+def clip01(value):
+ return min(1.0, max(0.0, value))
+
+
+def material_asset_builder(
+ doc: minidom.Document, mat: bpy.types.Material
+) -> minidom.Element:
+ """Builds a material asset node."""
+ el = doc.createElement('material')
+ el.setAttribute('name', mat.name)
+ # el.setAttribute('texture', '')
+ el.setAttribute('specular', str(clip01(mat.specular_intensity)))
+ el.setAttribute('shininess', str(clip01(1.0 - mat.roughness)))
+ el.setAttribute('reflectance', str(clip01(mat.metallic)))
+ el.setAttribute('rgba', ' '.join([str(c) for c in mat.diffuse_color]))
+ return el
+
+
+def export_to_xml(
+ doc: minidom.Document,
+ objects: Sequence[blender_scene.ObjectRef],
+ folder: str,
+ apply_mesh_modifiers: bool,
+) -> minidom.Element:
+ """Converts Blender scene objects to Mujoco assets."""
+ asset_el = doc.createElement('asset')
+
+ unique = set()
+ for obj in objects:
+ if obj.is_mesh and obj.mesh not in unique:
+ # Use the base mesh object as a reference, because that's the reference
+ # that will be shared between objects.
+ unique.add(obj.mesh)
+ if apply_mesh_modifiers:
+ # Make sure to export the mesh with modifiers applied, and if that's not
+ # possible (it may not be possible and None will be returned), default
+ # to the base mesh.
+ mesh_for_export = obj.get_modified_mesh() or obj.mesh
+ else:
+ mesh_for_export = obj.mesh
+ for el in mesh_asset_builder(doc, mesh_for_export, obj.materials, folder):
+ asset_el.appendChild(el)
+
+ for material in obj.materials:
+ if material not in unique:
+ unique.add(material)
+ asset_el.appendChild(material_asset_builder(doc, material))
+
+ return asset_el
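
Review note on `MujocoMesh.save`: the file it writes is a header of four integer counts (vertices, normals, texture coordinates, faces) followed by the flat float and integer arrays, all packed with `struct` in native byte order. The sketch below reproduces that layout for a single triangle and reads the header back. The helper names `pack_msh` and `unpack_msh_counts` are hypothetical and only illustrate the layout used in this file.

```python
import struct
from typing import Sequence, Tuple


def pack_msh(
    vertices: Sequence[float],
    normals: Sequence[float],
    uvs: Sequence[float],
    faces: Sequence[int],
) -> bytes:
  """Packs flat arrays using the same format string as MujocoMesh.save."""
  fmt = '4i{}f{}f{}f{}i'.format(
      len(vertices), len(normals), len(uvs), len(faces)
  )
  return struct.pack(
      fmt,
      len(vertices) // 3,  # Number of vertices (3 floats each).
      len(normals) // 3,  # Number of normals (3 floats each).
      len(uvs) // 2,  # Number of texture coordinates (2 floats each).
      len(faces) // 3,  # Number of triangles (3 indices each).
      *vertices,
      *normals,
      *uvs,
      *faces,
  )


def unpack_msh_counts(blob: bytes) -> Tuple[int, int, int, int]:
  """Reads the four header counts back from a packed blob."""
  return struct.unpack_from('4i', blob, 0)


# One triangle in the XY plane, with normals pointing along +Z.
blob = pack_msh(
    vertices=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
    normals=[0.0, 0.0, 1.0] * 3,
    uvs=[0.0, 0.0, 1.0, 0.0, 0.0, 1.0],
    faces=[0, 1, 2],
)
print(unpack_msh_counts(blob), len(blob))
```

Because the format string has no byte-order prefix, the output depends on the machine that runs the export. A `'<'` prefix would pin it to little-endian if that ever becomes a concern.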
diff --git a/dm_control/blender/mujoco_exporter/mujoco_assets_test.py b/dm_control/blender/mujoco_exporter/mujoco_assets_test.py
new file mode 100644
index 00000000..bbf977d3
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/mujoco_assets_test.py
@@ -0,0 +1,225 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for mujoco_assets.py."""
+
+import collections
+from unittest import mock
+from xml.dom import minidom
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.blender.mujoco_exporter import blender_scene
+from dm_control.blender.mujoco_exporter import mujoco_assets
+from dm_control.blender.mujoco_exporter import testing
+
+_DEFAULT_SETTINGS = dict(folder='', apply_mesh_modifiers=False)
+
+
+class MujocoAssetsTest(parameterized.TestCase):
+
+ def test_building_mesh_asset(self):
+ mesh_object = testing.build_mesh_object('mesh')
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_root.createElement = mock.MagicMock(spec=minidom.Element)
+
+ obj = blender_scene.ObjectRef(mesh_object)
+ with mock.patch(mujoco_assets.__name__ + '.open', mock.mock_open()):
+ element = mujoco_assets.export_to_xml(
+ xml_root, [obj], **_DEFAULT_SETTINGS
+ )
+
+ xml_root.createElement.assert_any_call('mesh')
+ element.setAttribute.assert_any_call('name', 'mesh')
+ # Note there's no folder in the name.
+ # The asset is assumed to be exported into the same folder as the .xml file.
+ element.setAttribute.assert_any_call('file', 'mesh.msh')
+
+ def test_exporting_mesh_asset(self):
+ mesh = testing.build_mesh(name='mesh.001')
+ mesh_object = testing.build_mesh_object('mesh', mesh=mesh)
+ xml_root = mock.MagicMock(spec=minidom.Document)
+
+ obj = blender_scene.ObjectRef(mesh_object)
+ mock_open = mock.mock_open()
+ with mock.patch(mujoco_assets.__name__ + '.open', mock_open):
+ mujoco_assets.export_to_xml(xml_root, [obj], '/folder', False)
+ mock_open.assert_called_once_with('/folder/mesh.001.msh', 'wb')
+
+ def test_exporting_material_asset(self):
+ material = testing.build_material(
+ 'mat', color=(0.1, 0.2, 0.3, 0.4), specular=0.5, metallic=0.6
+ )
+ xml_root = mock.MagicMock(spec=minidom.Document)
+
+ result = mujoco_assets.material_asset_builder(xml_root, material)
+ result.setAttribute.assert_any_call('name', 'mat')
+ result.setAttribute.assert_any_call('specular', '0.5')
+ result.setAttribute.assert_any_call('reflectance', '0.6')
+ result.setAttribute.assert_any_call('rgba', '0.1 0.2 0.3 0.4')
+
+ @parameterized.parameters([
+ [0.0, 1.0],
+ [0.25, 0.75],
+ [0.5, 0.5],
+ [1.0, 0.0],
+ [5.0, 0.0],
+ [-5.0, 1.0],
+ ])
+ def test_shininess_is_inverse_of_blender_mat_roughness(self, blender, mujoco):
+ material = testing.build_material('mat', roughness=blender)
+ xml_root = mock.MagicMock(spec=minidom.Document)
+
+ result = mujoco_assets.material_asset_builder(xml_root, material)
+ result.setAttribute.assert_any_call('shininess', str(mujoco))
+
+ def test_export_unique_asset_instances(self):
+ num_objects = 10
+ shared_material = testing.build_material('mat_1')
+ shared_mesh = testing.build_mesh('mesh.001', materials=[shared_material])
+
+ objects = [
+ blender_scene.ObjectRef(
+ testing.build_mesh_object('mesh', mesh=shared_mesh)
+ )
+ for _ in range(num_objects)
+ ]
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ with mock.patch(mujoco_assets.__name__ + '.open', mock.mock_open()):
+ mujoco_assets.export_to_xml(xml_root, objects, **_DEFAULT_SETTINGS)
+
+ create_element_call_counts = collections.Counter(
+ [m_[0][0] for m_ in xml_root.createElement.call_args_list]
+ )
+ self.assertEqual(create_element_call_counts['mesh'], 1)
+ self.assertEqual(create_element_call_counts['material'], 1)
+
+ def test_mesh_with_multiple_materials_split(self):
+ mat_1 = testing.build_material('mat_1')
+ mat_2 = testing.build_material('mat_2')
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[mat_1, mat_2],
+ faces=[[0, 1, 2], [3, 4, 5]],
+ vert_co=[
+ [0.1, 0.1, 0.1],
+ [0.2, 0.2, 0.2],
+ [0.3, 0.3, 0.3],
+ [0.4, 0.4, 0.4],
+ [0.5, 0.5, 0.5],
+ [0.6, 0.6, 0.6],
+ ],
+ mat_ids=[0, 1],
+ )
+
+ mesh_obj = blender_scene.ObjectRef(
+ testing.build_mesh_object('mesh', mesh=mesh)
+ )
+
+ with mock.patch(mujoco_assets.__name__ + '.open', mock.mock_open()):
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_elem = mock.MagicMock(spec=minidom.Element)
+ xml_root.createElement.return_value = xml_elem
+ mujoco_assets.export_to_xml(xml_root, [mesh_obj], **_DEFAULT_SETTINGS)
+
+ # Verify that two mesh elements were created...
+ create_element_call_counts = collections.Counter(
+ [m_[0][0] for m_ in xml_root.createElement.call_args_list]
+ )
+ self.assertEqual(create_element_call_counts['mesh'], 2)
+
+ # ...and that their names reflect the mesh-material pairing.
+ xml_elem.setAttribute.assert_any_call('file', 'mesh_mat_1.msh')
+ xml_elem.setAttribute.assert_any_call('file', 'mesh_mat_2.msh')
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_mat_1')
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_mat_2')
+
+
+class MujocoMeshTests(parameterized.TestCase):
+
+ def test_splitting_mesh_with_multiple_materials_into_submeshes(self):
+ mat_0 = testing.build_material('mat_0')
+ mat_1 = testing.build_material('mat_1')
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[mat_0, mat_1],
+ faces=[[0, 1, 2], [3, 4, 5]],
+ vert_co=[
+ [0.1, 0.1, 0.1],
+ [0.2, 0.2, 0.2],
+ [0.3, 0.3, 0.3],
+ [0.4, 0.4, 0.4],
+ [0.5, 0.5, 0.5],
+ [0.6, 0.6, 0.6],
+ ],
+ mat_ids=[0, 1],
+ )
+
+ mesh_mat_0 = mujoco_assets.MujocoMesh(mesh, material_idx=0, two_sided=False)
+ mesh_mat_1 = mujoco_assets.MujocoMesh(mesh, material_idx=1, two_sided=False)
+
+ self.assertEqual(
+ mesh_mat_0.vertices, [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3]
+ )
+ self.assertEqual(
+ mesh_mat_1.vertices, [0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6]
+ )
+ self.assertEqual(mesh_mat_0.faces, [0, 1, 2])
+ self.assertEqual(mesh_mat_1.faces, [0, 1, 2])
+
+ def test_one_sided_mesh(self):
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[testing.build_material('mat')],
+ faces=[[0, 1, 2]],
+ vert_co=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]],
+ normals=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]],
+ mat_ids=[0],
+ )
+
+ mesh = mujoco_assets.MujocoMesh(mesh, material_idx=0, two_sided=False)
+ self.assertEqual(mesh.faces, [0, 1, 2])
+ self.assertEqual(
+ mesh.vertices, [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3]
+ )
+ self.assertEqual(
+ mesh.normals, [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3]
+ )
+
+ def test_two_sided_mesh(self):
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[testing.build_material('mat')],
+ faces=[[0, 1, 2]],
+ vert_co=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]],
+ normals=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]],
+ mat_ids=[0],
+ )
+
+ mesh = mujoco_assets.MujocoMesh(mesh, material_idx=0, two_sided=True)
+ self.assertEqual(mesh.faces, [0, 1, 2, 5, 4, 3])
+ self.assertEqual(
+ mesh.vertices,
+ [.1, .1, .1, .2, .2, .2, .3, .3, .3,
+ .1, .1, .1, .2, .2, .2, .3, .3, .3,])
+ self.assertEqual(
+ mesh.normals,
+ [.1, .1, .1, .2, .2, .2, .3, .3, .3,
+ -.1, -.1, -.1, -.2, -.2, -.2, -.3, -.3, -.3,])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/blender/mujoco_exporter/mujoco_scene.py b/dm_control/blender/mujoco_exporter/mujoco_scene.py
new file mode 100644
index 00000000..c2d75549
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/mujoco_scene.py
@@ -0,0 +1,188 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco scene element builders and utilities."""
+
+from typing import Sequence
+from xml.dom import minidom
+
+from dm_control.blender.fake_core import mathutils
+from dm_control.blender.mujoco_exporter import blender_scene
+
+_ARMATURE = 'ARMATURE'
+_MESH = 'MESH'
+_VEC_ZERO = mathutils.Vector((0, 0, 0))
+_OZ = mathutils.Vector((0, 0, 1))
+
+
+def color_to_mjcf(color: mathutils.Color) -> str:
+ return '{} {} {}'.format(color.r, color.g, color.b)
+
+
+def vec_to_mjcf(vec: mathutils.Vector) -> str:
+ return '{} {} {}'.format(vec.x, vec.y, vec.z)
+
+
+def quat_to_mjcf(quat: mathutils.Quaternion) -> str:
+ return '{} {} {} {}'.format(quat.w, quat.x, quat.y, quat.z)
+
+
+def bool_to_mjcf(bool_val: bool) -> str:
+ return 'true' if bool_val else 'false'
+
+
+def body_builder(
+ doc: minidom.Document, blender_obj: blender_scene.ObjectRef
+) -> minidom.Element:
+ """Builds a mujoco body element."""
+ transform = blender_obj.get_local_transform()
+
+ el = doc.createElement('body')
+ el.setAttribute('name', blender_obj.name)
+ el.setAttribute('pos', vec_to_mjcf(transform.pos))
+ el.setAttribute('quat', quat_to_mjcf(transform.rot))
+ return el
+
+
+def light_builder(
+ doc: minidom.Document, light_obj: blender_scene.ObjectRef
+) -> minidom.Element:
+ """Builds an mjcf element that describes a light."""
+ assert light_obj.is_light
+ light = light_obj.as_light()
+
+ directional = bool_to_mjcf(light.type in ('SUN', 'SPOT'))
+ attenuation = '0 {} {}'.format(
+ light.linear_attenuation, light.quadratic_attenuation
+ )
+ transform = light_obj.get_local_transform()
+
+ el = doc.createElement('light')
+ el.setAttribute('name', light_obj.name)
+ el.setAttribute('pos', vec_to_mjcf(transform.pos))
+ el.setAttribute('dir', vec_to_mjcf(transform.rot @ _OZ))
+ el.setAttribute('directional', directional)
+ el.setAttribute('castshadow', bool_to_mjcf(light.use_shadow)) # pytype: disable=wrong-arg-types
+ el.setAttribute('diffuse', color_to_mjcf(light.color))
+ el.setAttribute('attenuation', attenuation)
+ return el
+
+
+def mesh_geom_builder(
+ doc: minidom.Document, mesh_obj: blender_scene.ObjectRef
+) -> Sequence[minidom.Element]:
+ """Builds mujoco geom elements for a mesh, one per valid material."""
+ mesh = mesh_obj.mesh
+ transform = mesh_obj.get_local_transform()
+
+ elements = []
+ mat_names = blender_scene.map_materials(lambda m: m.name, mesh_obj.materials)
+ # The mesh might not have any materials assigned.
+ # We still want to export such a geom, but without a reference to a
+ # material in the node.
+ # In that case we use a placeholder material with an empty name.
+ mat_names = mat_names or ['']
+ for mat_idx, mat_name in enumerate(mat_names):
+ if not blender_scene.is_material_mesh_pair_valid(mesh, mat_idx):
+ continue
+
+ obj_name = blender_scene.get_material_mesh_pair_name(
+ mesh_obj.name, mat_name
+ )
+ mesh_name = blender_scene.get_material_mesh_pair_name(mesh.name, mat_name)
+
+ el = doc.createElement('geom')
+ el.setAttribute('name', obj_name)
+ el.setAttribute('mesh', mesh_name)
+ el.setAttribute('pos', vec_to_mjcf(transform.pos))
+ el.setAttribute('quat', quat_to_mjcf(transform.rot))
+ el.setAttribute('type', 'mesh')
+ if mat_name:
+ el.setAttribute('material', mat_name)
+
+ elements.append(el)
+
+ return elements
+
+
+def joint_builder(
+ doc: minidom.Document,
+ dof: blender_scene.Dof,
+ dof_type: str,
+) -> minidom.Element:
+ """Builds a mujoco joint definition of the given type."""
+ el = doc.createElement('joint')
+ el.setAttribute('name', dof.name)
+ el.setAttribute('type', dof_type)
+ el.setAttribute('limited', bool_to_mjcf(dof.limited))
+ el.setAttribute('pos', vec_to_mjcf(_VEC_ZERO))
+ el.setAttribute('axis', vec_to_mjcf(dof.axis))
+ el.setAttribute('range', '{} {}'.format(dof.limits[0], dof.limits[1]))
+ return el
+
+
+def export_to_xml(
+ doc: minidom.Document,
+ objects: Sequence[blender_scene.ObjectRef],
+ armature_freejoint: bool,
+) -> minidom.Element:
+ """Converts Blender scene objects to Mujoco scene tree nodes."""
+ root = doc.createElement('worldbody')
+ parent_elements = {blender_scene.NoneRef: root}
+
+ for obj in objects:
+ # Build a subtree corresponding to this Blender object.
+ element = None
+ if obj.is_armature:
+ element = body_builder(doc, obj)
+ if armature_freejoint:
+ element.appendChild(doc.createElement('freejoint'))
+ elif obj.is_mesh:
+ if not obj.parent.is_none and not obj.parent.is_bone:
+ raise RuntimeError(
+ 'Mesh "{}" is parented to an object "{}", which is not a bone. '
+ 'Only mesh->bone parenting is supported at the moment'.format(
+ obj.name, obj.parent.name
+ )
+ )
+ geom_elements = mesh_geom_builder(doc, obj)
+ if len(geom_elements) > 1:
+ # If there's more than one geom, introduce a body to aggregate them
+ # under a single element. That element will then be associated with this
+ # blender object should other blender elements be parented to it.
+ element = doc.createElement('body')
+ for geom_el in geom_elements:
+ element.appendChild(geom_el)
+ elif len(geom_elements) == 1:
+ # Since there's only one geom, consider it the main element
+ element = geom_elements[0]
+ elif obj.is_light:
+ element = light_builder(doc, obj)
+ elif obj.is_bone:
+ element = body_builder(doc, obj)
+ for dof in obj.get_rotation_dofs():
+ element.appendChild(joint_builder(doc, dof, 'hinge'))
+
+ # Inject it into the scene tree.
+ if element:
+ parent = obj.parent
+ if parent:
+ parent_el = parent_elements[parent]
+ parent_el.appendChild(element)
+ parent_elements[obj] = element
+ else:
+ root.appendChild(element)
+
+ return root
diff --git a/dm_control/blender/mujoco_exporter/mujoco_scene_test.py b/dm_control/blender/mujoco_exporter/mujoco_scene_test.py
new file mode 100644
index 00000000..e63868c1
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/mujoco_scene_test.py
@@ -0,0 +1,195 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for mujoco_scene.py."""
+
+import collections
+from unittest import mock
+from xml.dom import minidom
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.blender.mujoco_exporter import blender_scene
+from dm_control.blender.mujoco_exporter import mujoco_scene
+from dm_control.blender.mujoco_exporter import testing
+
+_DEFAULT_PARAMS = dict(armature_freejoint=False)
+
+
+class ConverterTest(parameterized.TestCase):
+
+ def test_convert_armature(self):
+ armature = testing.build_armature('armature')
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_root.createElement = mock.MagicMock(spec=minidom.Element)
+
+ obj = blender_scene.ObjectRef(armature)
+ created_element = mujoco_scene.export_to_xml(
+ xml_root, [obj], **_DEFAULT_PARAMS
+ )
+
+ xml_root.createElement.assert_any_call('body')
+ created_element.setAttribute.assert_any_call('name', 'armature')
+
+ def test_convert_bone_without_constraints_creates_a_body(self):
+ bone = testing.build_bone('bone')
+ armature = testing.build_armature('armature', bones=[bone])
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_root.createElement.side_effect = mock.MagicMock(spec=minidom.Element)
+
+ bone_obj = blender_scene.ObjectRef(armature, bone)
+ armature_obj = blender_scene.ObjectRef(armature)
+ created_element = mujoco_scene.export_to_xml(
+ xml_root, [armature_obj, bone_obj], **_DEFAULT_PARAMS
+ )
+
+ xml_root.createElement.assert_any_call('body')
+ created_element.setAttribute.assert_any_call('name', 'bone_armature')
+
+ @parameterized.parameters([
+ # Point light
+ dict(
+ args=dict(lin_att=1, quad_att=2, shadow=True, light_type='POINT'),
+ out=dict(castshadow='true', attenuation='0 1 2', directional='false'),
+ ),
+ # Spot light
+ dict(
+ args=dict(lin_att=3, quad_att=4, shadow=True, light_type='SPOT'),
+ out=dict(castshadow='true', attenuation='0 3 4', directional='true'),
+ ),
+ # Directional light
+ dict(
+ args=dict(lin_att=5, quad_att=6, shadow=True, light_type='SUN'),
+ out=dict(castshadow='true', attenuation='0 5 6', directional='true'),
+ ),
+ # Shadows off
+ dict(
+ args=dict(lin_att=7, quad_att=8, shadow=False, light_type='SUN'),
+ out=dict(castshadow='false', attenuation='0 7 8', directional='true'),
+ ),
+ ])
+ def test_convert_light(self, args, out):
+ light = testing.build_light('light', **args)
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_root.createElement = mock.MagicMock(spec=minidom.Element)
+
+ obj = blender_scene.ObjectRef(light)
+ created_element = mujoco_scene.export_to_xml(
+ xml_root, [obj], **_DEFAULT_PARAMS
+ )
+
+ xml_root.createElement.assert_any_call('light')
+ created_element.setAttribute.assert_any_call('name', 'light')
+ for k, v in out.items():
+ created_element.setAttribute.assert_any_call(k, v)
+
+ def test_exported_mesh_references_mesh_asset_by_name(self):
+ mesh = testing.build_mesh('mesh_asset_name')
+ mesh_obj = testing.build_mesh_object('mesh_object_name', mesh=mesh)
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_root.createElement = mock.MagicMock(spec=minidom.Element)
+
+ obj = blender_scene.ObjectRef(mesh_obj)
+ created_element = mujoco_scene.export_to_xml(
+ xml_root, [obj], **_DEFAULT_PARAMS
+ )
+
+ xml_root.createElement.assert_any_call('geom')
+ created_element.setAttribute.assert_any_call('name', 'mesh_object_name')
+ created_element.setAttribute.assert_any_call('mesh', 'mesh_asset_name')
+
+ def test_mesh_with_no_material_keeps_its_name(self):
+ mesh = testing.build_mesh('mesh')
+ mesh_obj = testing.build_mesh_object('mesh_obj', mesh=mesh)
+ obj = blender_scene.ObjectRef(mesh_obj)
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_elem = mock.MagicMock(spec=minidom.Element)
+ xml_root.createElement.return_value = xml_elem
+ mujoco_scene.export_to_xml(xml_root, [obj], **_DEFAULT_PARAMS)
+
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_obj')
+ xml_elem.setAttribute.assert_any_call('mesh', 'mesh')
+
+ def test_mesh_with_materials_builds_multiple_geoms(self):
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[
+ testing.build_material('mat_1'),
+ testing.build_material('mat_2'),
+ ],
+ faces=[[0, 1, 2], [3, 4, 5]],
+ vert_co=[
+ [0.1, 0.1, 0.1],
+ [0.2, 0.2, 0.2],
+ [0.3, 0.3, 0.3],
+ [0.4, 0.4, 0.4],
+ [0.5, 0.5, 0.5],
+ [0.6, 0.6, 0.6],
+ ],
+ mat_ids=[0, 1],
+ )
+ mesh_obj = testing.build_mesh_object('mesh_obj', mesh=mesh)
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_elem = mock.MagicMock(spec=minidom.Element)
+ xml_root.createElement.return_value = xml_elem
+
+ obj = blender_scene.ObjectRef(mesh_obj)
+ mujoco_scene.export_to_xml(xml_root, [obj], **_DEFAULT_PARAMS)
+
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_obj_mat_1')
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_obj_mat_2')
+ xml_elem.setAttribute.assert_any_call('mesh', 'mesh_mat_1')
+ xml_elem.setAttribute.assert_any_call('mesh', 'mesh_mat_2')
+
+ def test_empty_submeshes_are_not_exported(self):
+ mesh = testing.build_mesh(
+ 'mesh',
+ materials=[
+ testing.build_material('mat_1'),
+ testing.build_material('mat_2'),
+ ],
+ faces=[[0, 1, 2], [3, 4, 5]],
+ vert_co=[
+ [0.1, 0.1, 0.1],
+ [0.2, 0.2, 0.2],
+ [0.3, 0.3, 0.3],
+ [0.4, 0.4, 0.4],
+ [0.5, 0.5, 0.5],
+ [0.6, 0.6, 0.6],
+ ],
+ mat_ids=[0, 0],
+ ) # All faces reference material 0
+ mesh_obj = testing.build_mesh_object('mesh_obj', mesh=mesh)
+
+ xml_root = mock.MagicMock(spec=minidom.Document)
+ xml_elem = mock.MagicMock(spec=minidom.Element)
+ xml_root.createElement.return_value = xml_elem
+
+ obj = blender_scene.ObjectRef(mesh_obj)
+ mujoco_scene.export_to_xml(xml_root, [obj], **_DEFAULT_PARAMS)
+
+ create_element_call_counts = collections.Counter(
+ [m_[0][0] for m_ in xml_root.createElement.call_args_list]
+ )
+ self.assertEqual(create_element_call_counts['geom'], 1)
+ xml_elem.setAttribute.assert_any_call('name', 'mesh_obj_mat_1')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/blender/mujoco_exporter/testing.py b/dm_control/blender/mujoco_exporter/testing.py
new file mode 100644
index 00000000..c325c38a
--- /dev/null
+++ b/dm_control/blender/mujoco_exporter/testing.py
@@ -0,0 +1,197 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Testing utilities."""
+from typing import Any, Optional, Sequence
+from unittest import mock
+
+from dm_control.blender.fake_core import bpy
+
+
+class FakePropCollection:
+ """Collection that simulates bpy_prop_collection.
+
+ Armature's bone collection is modeled using it, and the tested code depends
+ on some of its features - namely the conversion to list produces dict
+ values instead of its keys, as well as the collection behaving like a dict
+ otherwise.
+
+ @see
+ blender_scene.py : map_blender_tree - iterating over armature.data.bones
+ blender_scene.py : ObjectRef.parent.object_parent - dereferencing bone
+ by its name armature.data.bones[obj.parent_bone]
+ """
+
+ def __init__(self, objects_with_name):
+ if objects_with_name:
+ self._keys = [i.name for i in objects_with_name]
+ else:
+ self._keys = []
+ self._values = objects_with_name or []
+
+ def __len__(self):
+ return len(self._values)
+
+ def __getitem__(self, idx):
+ if isinstance(idx, str):
+ idx = self._keys.index(idx)
+ return self._values[idx]
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def __contains__(self, key):
+ return key in self._keys
+
+ def keys(self):
+ return self._keys
+
+ def values(self):
+ return self._values
+
+ def items(self):
+ return list(zip(self._keys, self._values))
+
+
+def build_armature(
+ name: str,
+ parent: Optional[Any] = None,
+ parent_bone: Optional[str] = None,
+ bones: Optional[Any] = None,
+) -> ...:
+ """Builds a mock armature object."""
+ obj = mock.MagicMock(spec=bpy.types.Object, type='ARMATURE')
+ obj.name = name
+ obj.parent = parent
+ obj.parent_bone = parent_bone
+ obj.data = mock.MagicMock()
+ obj.data.bones = FakePropCollection(bones)
+ obj.pose = mock.MagicMock()
+ obj.pose.bones = FakePropCollection(bones)
+ return obj
+
+
+def build_bone(name, parent=None, constraint=None):
+ """Builds a mock bone."""
+ bone = mock.MagicMock(spec=bpy.types.Bone)
+ bone.name = name
+ bone.parent = parent
+ bone.is_in_ik_chain = constraint is not None
+ bone.lock_ik_x = constraint.lock_x if constraint else False
+ bone.use_ik_limit_x = constraint.use_limit_x if constraint else False
+ bone.ik_min_x = constraint.min_x if constraint else 0.0
+ bone.ik_max_x = constraint.max_x if constraint else 0.0
+ bone.lock_ik_y = constraint.lock_y if constraint else False
+ bone.use_ik_limit_y = constraint.use_limit_y if constraint else False
+ bone.ik_min_y = constraint.min_y if constraint else 0.0
+ bone.ik_max_y = constraint.max_y if constraint else 0.0
+ bone.lock_ik_z = constraint.lock_z if constraint else False
+ bone.use_ik_limit_z = constraint.use_limit_z if constraint else False
+ bone.ik_min_z = constraint.min_z if constraint else 0.0
+ bone.ik_max_z = constraint.max_z if constraint else 0.0
+ return bone
+
+
+def build_rotation_constraint(
+ lock_x=False, use_limit_x=False, min_x=0, max_x=1,
+ lock_y=False, use_limit_y=False, min_y=0, max_y=1,
+ lock_z=False, use_limit_z=False, min_z=0, max_z=1):
+ """Builds a mock rotation constraint."""
+ c = mock.MagicMock()
+ c.lock_x = lock_x
+ c.use_limit_x = use_limit_x
+ c.min_x = min_x
+ c.max_x = max_x
+ c.lock_y = lock_y
+ c.use_limit_y = use_limit_y
+ c.min_y = min_y
+ c.max_y = max_y
+ c.lock_z = lock_z
+ c.use_limit_z = use_limit_z
+ c.min_z = min_z
+ c.max_z = max_z
+ return c
+
+
+def build_mesh(
+ name, materials=None, faces=None, vert_co=None, normals=None, mat_ids=None
+):
+ """Builds a mock triangle mesh."""
+ materials = materials or []
+ faces = faces or [[0, 1, 2]]
+ vert_co = vert_co or [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]]
+ normals = normals or vert_co
+ mat_ids = mat_ids or [0] * len(faces)
+
+ obj = mock.MagicMock(spec=bpy.types.Mesh)
+ obj.name = name
+ obj.materials = materials
+ obj.uv_layers = mock.MagicMock()
+ obj.loop_triangles = [
+ mock.MagicMock(vertices=f, material_index=i)
+ for f, i in zip(faces, mat_ids)
+ ]
+ obj.vertices = [
+ mock.MagicMock(co=c, normal=n) for c, n in zip(vert_co, normals)
+ ]
+
+ return obj
+
+
+def build_mesh_object(
+ name: str,
+ mesh: Optional[Any] = None,
+ parent: Optional[Any] = None,
+ parent_bone: Optional[str] = None) -> ...:
+ """Builds a mock object with a mesh assigned to it."""
+ obj = mock.MagicMock(spec=bpy.types.Object, type='MESH')
+ obj.name = name
+ obj.parent = parent
+ obj.parent_bone = parent_bone
+ obj.data = mesh or build_mesh(name)
+ return obj
+
+
+def build_light(name, light_type, lin_att, quad_att, shadow):
+ """Builds a mock light."""
+ obj = mock.MagicMock(spec=bpy.types.Object, type='LIGHT')
+ obj.name = name
+ obj.parent = None
+ obj.parent_bone = None
+ obj.data = mock.MagicMock(spec=bpy.types.Light)
+ obj.data.type = light_type
+ obj.data.linear_attenuation = lin_att
+ obj.data.quadratic_attenuation = quad_att
+ obj.data.use_shadow = shadow
+ return obj
+
+
+def build_material(
+ name: str,
+ color: Sequence[float] = (1, 1, 1, 1),
+ specular: float = 0.5,
+ metallic: float = 0.0,
+ roughness: float = 0.5,
+ use_backface_culling: bool = False,
+):
+ """Builds a mock material definition."""
+ obj = mock.MagicMock(spec=bpy.types.Material)
+ obj.name = name
+ obj.diffuse_color = color
+ obj.specular_intensity = specular
+ obj.metallic = metallic
+ obj.roughness = roughness
+ obj.use_backface_culling = use_backface_culling
+ return obj
diff --git a/dm_control/composer/__init__.py b/dm_control/composer/__init__.py
new file mode 100644
index 00000000..990d945d
--- /dev/null
+++ b/dm_control/composer/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module containing abstract base classes for Composer environments."""
+
+from dm_control.composer.arena import Arena
+from dm_control.composer.constants import * # pylint: disable=wildcard-import
+from dm_control.composer.define import cached_property
+from dm_control.composer.define import observable
+from dm_control.composer.entity import Entity
+from dm_control.composer.entity import FreePropObservableMixin
+from dm_control.composer.entity import ModelWrapperEntity
+from dm_control.composer.entity import Observables
+from dm_control.composer.environment import Environment
+from dm_control.composer.environment import EpisodeInitializationError
+from dm_control.composer.environment import HOOK_NAMES
+from dm_control.composer.environment import ObservationPadding
+from dm_control.composer.initializer import Initializer
+from dm_control.composer.robot import Robot
+from dm_control.composer.task import NullTask
+from dm_control.composer.task import Task
diff --git a/dm_control/composer/arena.py b/dm_control/composer/arena.py
new file mode 100644
index 00000000..c63c52a8
--- /dev/null
+++ b/dm_control/composer/arena.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""The base empty arena that defines global settings for Composer."""
+
+import os
+
+from dm_control import mjcf
+from dm_control.composer import entity as entity_module
+
+_ARENA_XML_PATH = os.path.join(os.path.dirname(__file__), 'arena.xml')
+
+
+class Arena(entity_module.Entity):
+ """The base empty arena that defines global settings for Composer."""
+
+ def __init__(self, *args, **kwargs):
+ self._mjcf_root = None # Declare that _mjcf_root exists to allay pytype.
+ super().__init__(*args, **kwargs)
+
+ # _build uses *args and **kwargs rather than named arguments, to get
+ # around a signature-mismatch error from pytype in derived classes.
+
+ def _build(self, *args, **kwargs) -> None:
+ """Initializes this arena.
+
+ The function takes two arguments through args, kwargs:
+ name: A string, the name of this arena. If `None`, use the model name
+ defined in the MJCF file.
+ xml_path: An optional path to an XML file that will override the default
+ composer arena MJCF.
+
+ Args:
+ *args: See above.
+ **kwargs: See above.
+ """
+ if args:
+ name = args[0]
+ else:
+ name = kwargs.get('name', None)
+ if len(args) > 1:
+ xml_path = args[1]
+ else:
+ xml_path = kwargs.get('xml_path', None)
+
+ self._mjcf_root = mjcf.from_path(xml_path or _ARENA_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+
+ def add_free_entity(self, entity):
+ """Includes an entity in the arena as a free-moving body."""
+ frame = self.attach(entity)
+ frame.add('freejoint')
+ return frame
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
diff --git a/dm_control/composer/arena.xml b/dm_control/composer/arena.xml
new file mode 100644
index 00000000..e8a17743
--- /dev/null
+++ b/dm_control/composer/arena.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/composer/constants.py b/dm_control/composer/constants.py
new file mode 100644
index 00000000..1c459b1e
--- /dev/null
+++ b/dm_control/composer/constants.py
@@ -0,0 +1,19 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining constant values for Composer."""
+
+
+SENSOR_SITES_GROUP = 4
diff --git a/dm_control/composer/define.py b/dm_control/composer/define.py
new file mode 100644
index 00000000..e9b78723
--- /dev/null
+++ b/dm_control/composer/define.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Decorators for Entity methods returning elements and observables."""
+
+import abc
+import threading
+
+
+class cached_property(property): # pylint: disable=invalid-name
+ """A property that is evaluated only once per object instance."""
+
+ def __init__(self, func, doc=None):
+ super().__init__(fget=func, doc=doc)
+ self.lock = threading.RLock()
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ name = self.fget.__name__
+ obj_dict = obj.__dict__
+ try:
+ # Try returning a precomputed value without locking first.
+ # Profiling shows that the lock takes up a non-trivial amount of time.
+ return obj_dict[name]
+ except KeyError:
+ # The value hasn't been computed, now we have to lock.
+ with self.lock:
+ try:
+ # Check again whether another thread has already computed the value.
+ return obj_dict[name]
+ except KeyError:
+ # Otherwise call the function, cache the result, and return it
+ return obj_dict.setdefault(name, self.fget(obj))
+
+
+# A decorator for base.Observables methods returning an observable. This
+# decorator should be used by abstract base classes to indicate sub-classes need
+# to implement a corresponding @observable annotated method.
+abstract_observable = abc.abstractproperty # pylint: disable=invalid-name
+
+
+class observable(cached_property): # pylint: disable=invalid-name
+ """A decorator for base.Observables methods returning an observable.
+
+ The body of the decorated function is evaluated at Entity construction time
+ and the observable is cached.
+ """
+ pass
diff --git a/dm_control/composer/entity.py b/dm_control/composer/entity.py
new file mode 100644
index 00000000..0a78818f
--- /dev/null
+++ b/dm_control/composer/entity.py
@@ -0,0 +1,605 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract entity class."""
+
+import abc
+import collections
+import os
+import weakref
+
+from absl import logging
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+_OPTION_KEYS = set(['update_interval', 'buffer_size', 'delay', 'aggregator',
+ 'corruptor', 'enabled'])
+
+_NO_ATTACHMENT_FRAME = 'No attachment frame found.'
+
+
+# The component order differs from that used by the open-source `tf` package.
+def _multiply_quaternions(quat1, quat2):
+ result = np.empty_like(quat1)
+ mjbindings.mjlib.mju_mulQuat(result, quat1, quat2)
+ return result
+
+
+def _rotate_vector(vec, quat):
+ """Rotates a vector by the given quaternion."""
+ result = np.empty_like(vec)
+ mjbindings.mjlib.mju_rotVecQuat(result, vec, quat)
+ return result
+
+
+class _ObservableKeys:
+ """Helper object that implements the `observables.dict_keys` functionality."""
+
+ def __init__(self, entity, observables):
+ self._entity = entity
+ self._observables = observables
+
+ def __getattr__(self, name):
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError as exc:
+ raise ValueError(
+ 'Cannot retrieve the full identifier of mjcf_model.') from exc
+ return os.path.join(model_identifier, name)
+
+ def __dir__(self):
+ out = set(self._observables.keys())
+ out.update(dir(super()))
+ return list(out)
+
+
+class Observables:
+ """Base-class for Entity observables.
+
+ Subclasses should declare getter methods that are annotated with the
+ @define.observable decorator and that return an observable object.
+ """
+
+ def __init__(self, entity):
+ self._entity = weakref.proxy(entity)
+
+ self._observables = collections.OrderedDict()
+ self._keys_helper = _ObservableKeys(self._entity, self._observables)
+
+ # Ensure consistent ordering.
+ for attr_name in sorted(dir(type(self))):
+ type_attr = getattr(type(self), attr_name)
+ if isinstance(type_attr, define.observable):
+ self._observables[attr_name] = getattr(self, attr_name)
+
+ @property
+ def dict_keys(self):
+ return self._keys_helper
+
+ def as_dict(self, fully_qualified=True):
+ """Returns an OrderedDict of observables belonging to this Entity.
+
+ The returned observables will include any added using the add_observable
+ method, as well as any generated by a method decorated with the
+ @define.observable annotation.
+
+ Args:
+ fully_qualified: (bool) Whether the dict keys should be prefixed with the
+ parent entity's full model identifier.
+ """
+
+ if fully_qualified:
+ # We need to make sure that this property doesn't raise an AttributeError,
+ # otherwise __getattr__ is executed and we get a very funky error.
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError as exc:
+ raise ValueError(
+ 'Cannot retrieve the full identifier of mjcf_model.') from exc
+
+ return collections.OrderedDict(
+ [(os.path.join(model_identifier, name), observable)
+ for name, observable in self._observables.items()])
+ else:
+ # Return a copy to prevent dict being edited.
+ return self._observables.copy()
+
+ def get_observable(self, name, name_fully_qualified=False):
+ """Returns the observable with the given name.
+
+ Args:
+ name: (str) The identifier of the observable.
+ name_fully_qualified: (bool) Whether the provided name is prefixed by the
+ model's full identifier.
+ """
+
+ if name_fully_qualified:
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError as exc:
+ raise ValueError(
+ 'Cannot retrieve the full identifier of mjcf_model.') from exc
+ return self._observables[name.replace(model_identifier, '')]
+ else:
+ return self._observables[name]
+
+ def set_options(self, options):
+ """Configure Observables with an options dict.
+
+ Args:
+ options: A dict of dicts of configuration options keyed on
+ observable names, or a dict of configuration options, which will
+ propagate those options to all observables.
+ """
+ if options is None:
+ options = {}
+ elif options.keys() and set(options.keys()).issubset(_OPTION_KEYS):
+ options = dict([(key, options) for key in self._observables.keys()])
+
+ for obs_key, obs_options in options.items():
+ try:
+ obs = self._observables[obs_key]
+ except KeyError as exc:
+ raise KeyError('No observable with name {!r}'.format(obs_key)) from exc
+ obs.configure(**obs_options)
+
+ def enable_all(self):
+ """Enable all observables of this entity."""
+ for obs in self._observables.values():
+ obs.enabled = True
+
+ def disable_all(self):
+ """Disable all observables of this entity."""
+ for obs in self._observables.values():
+ obs.enabled = False
+
+ def add_observable(self, name, observable, enabled=True):
+ self._observables[name] = observable
+ self._observables[name].enabled = enabled
+
+
+class FreePropObservableMixin(metaclass=abc.ABCMeta):
+ """Enforce observables of a free-moving object."""
+
+ @property
+ @abc.abstractmethod
+ def position(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def orientation(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def linear_velocity(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def angular_velocity(self):
+ pass
+
+
+class Entity(metaclass=abc.ABCMeta):
+ """The abstract base class for an entity in a Composer environment."""
+
+ def __init__(self, *args, **kwargs):
+ """Entity constructor.
+
+ Subclasses should not override this method, instead implement a _build
+ method.
+
+ Args:
+ *args: Arguments passed through to the _build method.
+ **kwargs: Keyword arguments. Passed through to the _build method, apart
+ from the following.
+ `observable_options`: A dictionary of Observable
+ configuration options.
+ """
+ self._post_init_hooks = []
+
+ self._parent = None
+ self._attached = []
+
+ try:
+ observable_options = kwargs.pop('observable_options')
+ except KeyError:
+ observable_options = None
+
+ self._build(*args, **kwargs)
+ self._observables = self._build_observables()
+ self._observables.set_options(observable_options)
+
+ @abc.abstractmethod
+ def _build(self, *args, **kwargs):
+ """Entity initialization method to be overridden by subclasses."""
+ raise NotImplementedError
+
+ def _build_observables(self):
+ """Entity observables initialization method.
+
+ Returns:
+ An object subclassing the Observables class.
+ """
+ return Observables(self)
+
+ def iter_entities(self, exclude_self=False):
+ """An iterator that recursively iterates through all attached entities.
+
+ Args:
+ exclude_self: (optional) Whether to exclude this `Entity` itself from the
+ iterator.
+
+ Yields:
+ If `exclude_self` is `False`, the first value yielded is this Entity
+ itself. The following Entities are then yielded recursively in a
+ depth-first fashion, following the order in which the Entities are
+ attached.
+ """
+ if not exclude_self:
+ yield self
+ for attached_entity in self._attached:
+ for attached_entity_of_attached_entity in attached_entity.iter_entities():
+ yield attached_entity_of_attached_entity
+
+ @property
+ def observables(self):
+ """The observables defined by this entity."""
+ return self._observables
+
+ def initialize_episode_mjcf(self, random_state):
+ """Callback executed when the MJCF model is modified between episodes."""
+ pass
+
+ def after_compile(self, physics, random_state):
+ """Callback executed after the MuJoCo Physics is recompiled."""
+ pass
+
+ def initialize_episode(self, physics, random_state):
+ """Callback executed during episode initialization."""
+ pass
+
+ def before_step(self, physics, random_state):
+ """Callback executed before an agent control step."""
+ pass
+
+ def before_substep(self, physics, random_state):
+ """Callback executed before a simulation step."""
+ pass
+
+ def after_substep(self, physics, random_state):
+ """Callback executed after a simulation step."""
+ pass
+
+ def after_step(self, physics, random_state):
+ """Callback executed after an agent control step."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def mjcf_model(self):
+ raise NotImplementedError
+
+ def attach(self, entity, attach_site=None):
+ """Attaches an `Entity` without any additional degrees of freedom.
+
+ Args:
+ entity: The `Entity` to attach.
+ attach_site: (optional) The site to which to attach the entity's model. If
+ not set, defaults to self.attachment_site.
+
+ Returns:
+ The frame of the attached model.
+ """
+
+ if attach_site is None:
+ attach_site = self.attachment_site
+
+ frame = attach_site.attach(entity.mjcf_model)
+ self._attached.append(entity)
+ entity._parent = weakref.ref(self) # pylint: disable=protected-access
+ return frame
+
+ def detach(self):
+ """Detaches this entity if it has previously been attached."""
+ if self._parent is not None:
+ parent = self._parent() # pylint: disable=not-callable
+ if parent: # Weakref might dereference to None during garbage collection.
+ self.mjcf_model.detach()
+ parent._attached.remove(self) # pylint: disable=protected-access
+ self._parent = None
+ else:
+ raise RuntimeError('Cannot detach an entity that is not attached.')
+
+ @property
+ def parent(self):
+ """Returns the `Entity` to which this entity is attached, or `None`."""
+ return self._parent() if self._parent else None # pylint: disable=not-callable
+
+ @property
+ def attachment_site(self):
+ return self.mjcf_model
+
+ @property
+ def root_body(self):
+ if self.parent:
+ return mjcf.get_attachment_frame(self.mjcf_model)
+ else:
+ return self.mjcf_model.worldbody
+
+ def global_vector_to_local_frame(self, physics, vec_in_world_frame):
+ """Linearly transforms a world-frame vector into entity's local frame.
+
+ Note that this function does not perform an affine transformation of the
+ vector. In other words, the input vector is assumed to be specified with
+ respect to the same origin as this entity's local frame. This function
+ can also be applied to matrices whose innermost dimensions are either 2 or
+ 3. In this case, a matrix with the same leading dimensions is returned
+ where the innermost vectors are replaced by their values computed in the
+ local frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ vec_in_world_frame: A NumPy array with last dimension of shape (2,) or
+ (3,) that represents a vector quantity in the world frame.
+
+ Returns:
+ The same quantity as `vec_in_world_frame` but reexpressed in this
+ entity's local frame. The returned np.array has the same shape as
+ np.asarray(vec_in_world_frame).
+
+ Raises:
+ ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
+ or (3,).
+ """
+ vec_in_world_frame = np.asarray(vec_in_world_frame)
+
+ xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
+ # The ordering of the np.dot is such that the transformation holds for any
+ # matrix whose final dimensions are (2,) or (3,).
+ if vec_in_world_frame.shape[-1] == 2:
+ return np.dot(vec_in_world_frame, xmat[:2, :2])
+ elif vec_in_world_frame.shape[-1] == 3:
+ return np.dot(vec_in_world_frame, xmat)
+ else:
+ raise ValueError('`vec_in_world_frame` should have shape with final '
+ 'dimension 2 or 3: got {}'.format(
+ vec_in_world_frame.shape))
+
+ def global_xmat_to_local_frame(self, physics, xmat):
+ """Transforms another entity's `xmat` into this entity's local frame.
+
+ This function takes another entity's (E) xmat, which is an SO(3) matrix
+ from E's frame to the world frame, and turns it to a matrix that transforms
+ from E's frame into this entity's local frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ xmat: A NumPy array of shape (3, 3) or (9,) that represents another
+ entity's xmat.
+
+ Returns:
+ The `xmat` reexpressed in this entity's local frame. The returned
+ np.array has the same shape as np.asarray(xmat).
+
+ Raises:
+ ValueError: if `xmat` does not have shape (3, 3) or (9,).
+ """
+ xmat = np.asarray(xmat)
+
+ input_shape = xmat.shape
+ if xmat.shape == (9,):
+ xmat = np.reshape(xmat, (3, 3))
+
+ self_xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
+ if xmat.shape == (3, 3):
+ return np.reshape(np.dot(self_xmat.T, xmat), input_shape)
+ else:
+ raise ValueError('`xmat` should have shape (3, 3) or (9,): got {}'.format(
+ xmat.shape))
+
+ def get_pose(self, physics):
+ """Get the position and orientation of this entity relative to its parent.
+
+ Note that the semantics differ slightly depending on whether or not the
+ entity has a free joint:
+
+ * If it has a free joint the position and orientation are always given in
+ global coordinates.
+ * If the entity is fixed or attached with a different joint type then the
+ position and orientation are given relative to the parent frame.
+
+ For entities that are either attached directly to the worldbody, or to other
+ entities that are positioned at the global origin (e.g. the arena) the
+ global and relative poses are equivalent.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+
+ Returns:
+ A 2-tuple where the first entry is a (3,) numpy array representing the
+ position and the second is a (4,) numpy array representing orientation as
+ a quaternion.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ position = physics.bind(root_joint).qpos[:3]
+ quaternion = physics.bind(root_joint).qpos[3:]
+ else:
+ attachment_frame = mjcf.get_attachment_frame(self.mjcf_model)
+ if attachment_frame is None:
+ raise RuntimeError(_NO_ATTACHMENT_FRAME)
+ position = physics.bind(attachment_frame).pos
+ quaternion = physics.bind(attachment_frame).quat
+ return position, quaternion
+
+ def set_pose(self, physics, position=None, quaternion=None):
+ """Sets position and/or orientation of this entity relative to its parent.
+
+ If the entity is attached with a free joint, this method will set the
+ respective DoFs of the joint. If the entity is either fixed or attached with
+ a different joint type, this method will update the position and/or
+ quaternion of the attachment frame.
+
+ Note that the semantics differ slightly between the two cases: the DoFs of a
+ free body are specified in global coordinates, whereas the position of a
+ non-free body is specified in relative coordinates with respect to the
+ parent frame. However, for entities that are either attached directly to the
+ worldbody, or to other entities that are positioned at the global origin
+ (e.g. the arena), there is no difference between the two cases.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: (optional) A NumPy array of size 3.
+ quaternion: (optional) A NumPy array of size 4.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ if position is not None:
+ physics.bind(root_joint).qpos[:3] = position
+ if quaternion is not None:
+ normalised_quaternion = quaternion / np.linalg.norm(quaternion)
+ physics.bind(root_joint).qpos[3:] = normalised_quaternion
+ else:
+ attachment_frame = mjcf.get_attachment_frame(self.mjcf_model)
+ if attachment_frame is None:
+ raise RuntimeError(_NO_ATTACHMENT_FRAME)
+ if position is not None:
+ physics.bind(attachment_frame).pos = position
+ if quaternion is not None:
+ normalised_quaternion = quaternion / np.linalg.norm(quaternion)
+ physics.bind(attachment_frame).quat = normalised_quaternion
+
+ def shift_pose(self,
+ physics,
+ position=None,
+ quaternion=None,
+ rotate_velocity=False):
+ """Shifts the position and/or orientation from its current configuration.
+
+ This is a convenience function that performs the same operation as
+ `set_pose`, but where the specified `position` is added to the current
+ position, and the specified `quaternion` is premultiplied to the current
+ quaternion.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: (optional) A NumPy array of size 3.
+ quaternion: (optional) A NumPy array of size 4.
+ rotate_velocity: (optional) A bool, whether to shift the current linear
+ velocity along with the pose. This will rotate the current linear
+ velocity, which is expressed relative to the world frame. The angular
+ velocity, which is expressed relative to the local frame is left
+ unchanged.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ current_position, current_quaternion = self.get_pose(physics)
+ new_position, new_quaternion = None, None
+ if position is not None:
+ new_position = current_position + position
+ if quaternion is not None:
+ quaternion = np.asarray(quaternion, dtype=np.float64)
+ new_quaternion = _multiply_quaternions(quaternion, current_quaternion)
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint and rotate_velocity:
+ # Rotate the linear velocity. The angular velocity (qvel[3:])
+ # is left unchanged, as it is expressed in the local frame.
+ # When rotating the body frame the angular velocity already
+ # tracks the rotation but the linear velocity does not.
+ velocity = physics.bind(root_joint).qvel[:3]
+ rotated_velocity = _rotate_vector(velocity, quaternion)
+ self.set_velocity(physics, rotated_velocity)
+ self.set_pose(physics, new_position, new_quaternion)
+
+ def get_velocity(self, physics):
+ """Gets the linear and angular velocity of this free entity.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+
+ Returns:
+ A 2-tuple where the first entry is a (3,) numpy array representing the
+ linear velocity and the second is a (3,) numpy array representing the
+ angular velocity.
+
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ velocity = physics.bind(root_joint).qvel[:3]
+ angular_velocity = physics.bind(root_joint).qvel[3:]
+ return velocity, angular_velocity
+ else:
+ raise ValueError('get_velocity cannot be used on a non-free entity')
+
+ def set_velocity(self, physics, velocity=None, angular_velocity=None):
+ """Sets the linear velocity and/or angular velocity of this free entity.
+
+ If the entity is attached with a free joint, this method will set the
+ respective DoFs of the joint. Otherwise a warning is logged.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ velocity: (optional) A NumPy array of size 3 specifying the
+ linear velocity.
+ angular_velocity: (optional) A NumPy array of size 3 specifying the
+ angular velocity.
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ if velocity is not None:
+ physics.bind(root_joint).qvel[:3] = velocity
+ if angular_velocity is not None:
+ physics.bind(root_joint).qvel[3:] = angular_velocity
+ else:
+ logging.warning('Cannot set velocity on Entity with no free joint.')
+
+ def configure_joints(self, physics, position):
+ """Configures this entity's internal joints.
+
+ The default implementation of this method simply sets the `qpos` of all
+ joints in this entity to the values specified in the `position` argument.
+ Entity subclasses with actuated joints may override this method to achieve a
+ stable reconfiguration of joint positions, for example the control signal
+ of position actuators may be changed to match the new joint positions.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: The desired position of this entity's joints.
+ """
+ joints = self.mjcf_model.find_all('joint', exclude_attachments=True)
+ physics.bind(joints).qpos = position
+
+
+class ModelWrapperEntity(Entity):
+ """An entity class that wraps an MJCF model without any additional logic."""
+
+ def _build(self, mjcf_model):
+ self._mjcf_model = mjcf_model
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_model
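Two conventions in `Entity` are easy to get backwards: `shift_pose` premultiplies the requested rotation onto the current quaternion, and `global_vector_to_local_frame` multiplies a row vector by `xmat`. The sketch below checks both in plain Python, assuming MuJoCo's `(w, x, y, z)` quaternion layout and assuming that `xmat` stores the frame axes as columns. `multiply_quaternions` and `world_to_local` are hypothetical stand-ins for `mju_mulQuat` and the `np.dot` call, not part of dm_control. The input values are taken from `entity_test.py` in this diff.

```python
import math


def multiply_quaternions(quat1, quat2):
  """Hamilton product of two (w, x, y, z) quaternions (hypothetical helper)."""
  w1, x1, y1, z1 = quat1
  w2, x2, y2, z2 = quat2
  return (
      w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
      w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
      w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
      w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
  )


def world_to_local(vec, xmat):
  """Row-vector product vec @ xmat, equivalent to applying xmat transposed."""
  return tuple(
      sum(vec[i] * xmat[i][j] for i in range(3)) for j in range(3))


# shift_pose: the new rotation goes on the left of the current one.
current = (math.cos(math.pi / 8), math.sin(math.pi / 8), 0.0, 0.0)  # 45 deg, X
shift = (math.cos(math.pi / 4), 0.0, math.sin(math.pi / 4), 0.0)    # 90 deg, Y
shifted = multiply_quaternions(shift, current)
print([round(component, 5) for component in shifted])

# global_vector_to_local_frame: a frame built from xyaxes=[0, 1, 0, -1, 0, 0]
# has local x along world y and local y along world -x. Under the
# axes-as-columns assumption, its rotation matrix is:
XMAT = (
    (0, -1, 0),
    (1, 0, 0),
    (0, 0, 1),
)
print(world_to_local((0, 1, 0), XMAT))
print(world_to_local((-1, 0, 0), XMAT))
```

The quaternion result agrees with the third entry of `_TEST_ROTATIONS`, and the two vector results agree with the 3D cases in `testGlobalVectorToLocalFrame`. Swapping the argument order in `multiply_quaternions` would flip the sign of the last component, which is why the docstring's "premultiplied" matters.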
diff --git a/dm_control/composer/entity_test.py b/dm_control/composer/entity_test.py
new file mode 100644
index 00000000..500967d8
--- /dev/null
+++ b/dm_control/composer/entity_test.py
@@ -0,0 +1,427 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for composer.Entity."""
+
+import itertools
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer import arena
+from dm_control.composer import define
+from dm_control.composer import entity
+from dm_control.composer.observation.observable import base as observable
+import numpy as np
+
+_NO_ROTATION = (1, 0, 0, 0) # Tests support for non-arrays and non-floats.
+_NINETY_DEGREES_ABOUT_X = np.array(
+ [np.cos(np.pi / 4), np.sin(np.pi / 4), 0., 0.])
+_NINETY_DEGREES_ABOUT_Y = np.array(
+ [np.cos(np.pi / 4), 0., np.sin(np.pi / 4), 0.])
+_NINETY_DEGREES_ABOUT_Z = np.array(
+ [np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)])
+_FORTYFIVE_DEGREES_ABOUT_X = np.array(
+ [np.cos(np.pi / 8), np.sin(np.pi / 8), 0., 0.])
+
+_TEST_ROTATIONS = [
+ # Triplets of original rotation, new rotation and final rotation.
+ (None, _NO_ROTATION, _NO_ROTATION),
+ (_NO_ROTATION, _NINETY_DEGREES_ABOUT_Z, _NINETY_DEGREES_ABOUT_Z),
+ (_FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Y,
+ np.array([0.65328, 0.2706, 0.65328, -0.2706])),
+]
+
+
+def _param_product(**param_lists):
+ keys, values = zip(*param_lists.items())
+ for combination in itertools.product(*values):
+ yield dict(zip(keys, combination))
+
+
+class TestEntity(entity.Entity):
+ """Simple test entity that does nothing but declare some observables."""
+
+ def _build(self, name='test_entity'):
+ self._mjcf_root = mjcf.element.RootElement(model=name)
+ self._mjcf_root.worldbody.add('geom', type='sphere', size=(0.1,))
+
+ def _build_observables(self):
+ return TestEntityObservables(self)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class TestEntityObservables(entity.Observables):
+ """Trivial observables for the test entity."""
+
+ @define.observable
+ def observable0(self):
+ return observable.Generic(lambda phys: 0.0)
+
+ @define.observable
+ def observable1(self):
+ return observable.Generic(lambda phys: 1.0)
+
+
+class EntityTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.entity = TestEntity()
+
+ def testNumObservables(self):
+ """Tests that the observables dict has the right number of entries."""
+ self.assertLen(self.entity.observables.as_dict(), 2)
+
+ def testObservableNames(self):
+ """Tests that the observables dict keys correspond to the observable names.
+ """
+ obs = self.entity.observables.as_dict()
+ self.assertIn('observable0', obs)
+ self.assertIn('observable1', obs)
+
+ subentity = TestEntity(name='subentity')
+ self.entity.attach(subentity)
+ self.assertIn('subentity/observable0', subentity.observables.as_dict())
+ self.assertEqual(subentity.observables.dict_keys.observable0,
+ 'subentity/observable0')
+ self.assertIn('observable0', dir(subentity.observables.dict_keys))
+ self.assertIn('subentity/observable1', subentity.observables.as_dict())
+ self.assertEqual(subentity.observables.dict_keys.observable1,
+ 'subentity/observable1')
+ self.assertIn('observable1', dir(subentity.observables.dict_keys))
+
+ def testEnableDisableObservables(self):
+ """Tests the enable and disable functionality for observables."""
+ all_obs = self.entity.observables.as_dict()
+
+ self.entity.observables.enable_all()
+ for obs in all_obs.values():
+ self.assertTrue(obs.enabled)
+
+ self.entity.observables.disable_all()
+ for obs in all_obs.values():
+ self.assertFalse(obs.enabled)
+
+ self.entity.observables.observable0.enabled = True
+ self.assertTrue(all_obs['observable0'].enabled)
+
+ def testObservableDefaultOptions(self):
+ corruptor = lambda x: x
+ options = {
+ 'update_interval': 2,
+ 'buffer_size': 10,
+ 'delay': 1,
+ 'aggregator': 'max',
+ 'corruptor': corruptor,
+ 'enabled': True
+ }
+ self.entity.observables.set_options(options)
+
+ for obs in self.entity.observables.as_dict().values():
+ self.assertEqual(obs.update_interval, 2)
+ self.assertEqual(obs.delay, 1)
+ self.assertEqual(obs.buffer_size, 10)
+ self.assertEqual(obs.aggregator, observable.AGGREGATORS['max'])
+ self.assertEqual(obs.corruptor, corruptor)
+ self.assertTrue(obs.enabled)
+
+ def testObservablePartialDefaultOptions(self):
+ options = {'update_interval': 2, 'delay': 1}
+ self.entity.observables.set_options(options)
+
+ for obs in self.entity.observables.as_dict().values():
+ self.assertEqual(obs.update_interval, 2)
+ self.assertEqual(obs.delay, 1)
+ self.assertIsNone(obs.buffer_size)
+ self.assertIsNone(obs.aggregator)
+ self.assertIsNone(obs.corruptor)
+
+ def testObservableOptionsInvalidName(self):
+ options = {'asdf': None}
+ with self.assertRaisesRegex(KeyError, 'No observable with name \'asdf\''):
+ self.entity.observables.set_options(options)
+
+ def testObservableInvalidOptions(self):
+ options = {'observable0': {'asdf': 2}}
+ with self.assertRaisesRegex(AttributeError,
+ 'Cannot add attribute asdf in configure.'):
+ self.entity.observables.set_options(options)
+
+ def testObservableOptions(self):
+ options = {
+ 'observable0': {
+ 'update_interval': 2,
+ 'delay': 3
+ },
+ 'observable1': {
+ 'update_interval': 4,
+ 'delay': 5
+ }
+ }
+ self.entity.observables.set_options(options)
+ observables = self.entity.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertIsNone(observables['observable0'].buffer_size)
+ self.assertIsNone(observables['observable0'].aggregator)
+ self.assertIsNone(observables['observable0'].corruptor)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 4)
+ self.assertEqual(observables['observable1'].delay, 5)
+ self.assertIsNone(observables['observable1'].buffer_size)
+ self.assertIsNone(observables['observable1'].aggregator)
+ self.assertIsNone(observables['observable1'].corruptor)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testObservableOptionsEntityConstructor(self):
+ options = {
+ 'observable0': {
+ 'update_interval': 2,
+ 'delay': 3
+ },
+ 'observable1': {
+ 'update_interval': 4,
+ 'delay': 5
+ }
+ }
+ ent = TestEntity(observable_options=options)
+ observables = ent.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertIsNone(observables['observable0'].buffer_size)
+ self.assertIsNone(observables['observable0'].aggregator)
+ self.assertIsNone(observables['observable0'].corruptor)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 4)
+ self.assertEqual(observables['observable1'].delay, 5)
+ self.assertIsNone(observables['observable1'].buffer_size)
+ self.assertIsNone(observables['observable1'].aggregator)
+ self.assertIsNone(observables['observable1'].corruptor)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testObservablePartialOptions(self):
+ options = {'observable0': {'update_interval': 2, 'delay': 3}}
+ self.entity.observables.set_options(options)
+ observables = self.entity.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertIsNone(observables['observable0'].buffer_size)
+ self.assertIsNone(observables['observable0'].aggregator)
+ self.assertIsNone(observables['observable0'].corruptor)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 1)
+ self.assertIsNone(observables['observable1'].delay)
+ self.assertIsNone(observables['observable1'].buffer_size)
+ self.assertIsNone(observables['observable1'].aggregator)
+ self.assertIsNone(observables['observable1'].corruptor)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testAttach(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+
+ self.assertIsNone(entities[0].parent)
+ self.assertIs(entities[1].parent, entities[0])
+ self.assertIs(entities[2].parent, entities[1])
+ self.assertIs(entities[3].parent, entities[0])
+
+ self.assertIsNone(entities[0].mjcf_model.parent_model)
+ self.assertIs(entities[1].mjcf_model.parent_model, entities[0].mjcf_model)
+ self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model)
+ self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model)
+
+ self.assertEqual(list(entities[0].iter_entities()), entities)
+
+ def testDetach(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+
+ entities[1].detach()
+ with self.assertRaisesRegex(RuntimeError, 'not attached'):
+ entities[1].detach()
+
+ self.assertIsNone(entities[0].parent)
+ self.assertIsNone(entities[1].parent)
+ self.assertIs(entities[2].parent, entities[1])
+ self.assertIs(entities[3].parent, entities[0])
+
+ self.assertIsNone(entities[0].mjcf_model.parent_model)
+ self.assertIsNone(entities[1].mjcf_model.parent_model)
+ self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model)
+ self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model)
+
+ self.assertEqual(list(entities[0].iter_entities()),
+ [entities[0], entities[3]])
+
+ def testIterEntitiesExcludeSelf(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+ self.assertEqual(
+ list(entities[0].iter_entities(exclude_self=True)), entities[1:])
+
+ def testGlobalVectorToLocalFrame(self):
+ parent = TestEntity()
+ parent.mjcf_model.worldbody.add(
+ 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
+ physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
+
+ # 3D vectors
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 1, 0]),
+ [1, 0, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [-1, 0, 0]),
+ [0, 1, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 0, 1]),
+ [0, 0, 1], atol=1e-10)
+
+ # 2D vectors; z-component is ignored
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 1]),
+ [1, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [-1, 0]),
+ [0, 1], atol=1e-10)
+
+ def testGlobalMatrixToLocalFrame(self):
+ parent = TestEntity()
+ parent.mjcf_model.worldbody.add(
+ 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
+ physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
+
+ rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])
+ ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])
+
+ np.testing.assert_allclose(
+ self.entity.global_xmat_to_local_frame(physics, rotation_atob),
+ ego_rotation_atob, atol=1e-10)
+
+ flat_rotation_atob = np.reshape(rotation_atob, -1)
+ flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1)
+ np.testing.assert_allclose(
+ self.entity.global_xmat_to_local_frame(
+ physics, flat_rotation_atob),
+ flat_rotation_ego_atob, atol=1e-10)
+
+ @parameterized.parameters(*_param_product(
+ position=[None, [1., 0., -1.]],
+ quaternion=[None, _FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Z],
+ freejoint=[False, True],
+ ))
+ def testSetPose(self, position, quaternion, freejoint):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ if freejoint:
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ if quaternion is None:
+ ground_truth_quat = _NO_ROTATION
+ else:
+ ground_truth_quat = quaternion
+
+ if position is None:
+ ground_truth_pos = np.zeros(shape=(3,))
+ else:
+ ground_truth_pos = position
+
+ subentity.set_pose(physics, position=position, quaternion=quaternion)
+
+ np.testing.assert_allclose(physics.bind(frame).xpos, ground_truth_pos)
+ np.testing.assert_allclose(physics.bind(frame).xquat, ground_truth_quat)
+
+ @parameterized.parameters(*_param_product(
+ original_position=[[-2, -1, -1.], [1., 0., -1.]],
+ position=[None, [1., 0., -1.]],
+ original_quaternion=_TEST_ROTATIONS[0],
+ quaternion=_TEST_ROTATIONS[1],
+ expected_quaternion=_TEST_ROTATIONS[2],
+ freejoint=[False, True],
+ ))
+ def testShiftPose(self, original_position, position, original_quaternion,
+ quaternion, expected_quaternion, freejoint):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ if freejoint:
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ # Set the original position
+ subentity.set_pose(
+ physics, position=original_position, quaternion=original_quaternion)
+
+ if position is None:
+ ground_truth_pos = original_position
+ else:
+ ground_truth_pos = original_position + np.array(position)
+ subentity.shift_pose(physics, position=position, quaternion=quaternion)
+ np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos)
+
+ updated_quat = physics.bind(frame).xquat
+ np.testing.assert_array_almost_equal(updated_quat, expected_quaternion,
+ 1e-4)
+
+ @parameterized.parameters(False, True)
+ def testShiftPoseWithVelocity(self, rotate_velocity):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ # Set the original position
+ subentity.set_pose(physics, position=[0., 0., 0.])
+
+ # Set velocity in y dim.
+ subentity.set_velocity(physics, [0., 1., 0.])
+
+ # Rotate the entity around the z axis.
+ subentity.shift_pose(
+ physics, quaternion=[0., 0., 0., 1.], rotate_velocity=rotate_velocity)
+
+ physics.forward()
+ updated_position, _ = subentity.get_pose(physics)
+ if rotate_velocity:
+ # Should not have moved in the y dim.
+ np.testing.assert_array_almost_equal(updated_position[1], 0.)
+ else:
+ # Should not have moved in the x dim.
+ np.testing.assert_array_almost_equal(updated_position[0], 0.)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/environment.py b/dm_control/composer/environment.py
new file mode 100644
index 00000000..18a50dac
--- /dev/null
+++ b/dm_control/composer/environment.py
@@ -0,0 +1,517 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RL environment classes for Composer tasks."""
+
+import enum
+import warnings
+import weakref
+
+from absl import logging
+from dm_control import mjcf
+from dm_control.composer import observation
+from dm_control.rl import control
+import dm_env
+import numpy as np
+
+warnings.simplefilter('always', DeprecationWarning)
+
+_STEPS_LOGGING_INTERVAL = 10000
+
+HOOK_NAMES = ('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+
+_empty_function = lambda: None
+
+
+def _empty_function_with_docstring():
+ """Some docstring."""
+
+_EMPTY_CODE = _empty_function.__code__.co_code
+_EMPTY_WITH_DOCSTRING_CODE = _empty_function_with_docstring.__code__.co_code
+
+
+def _callable_is_trivial(f):
+ return (f.__code__.co_code == _EMPTY_CODE or
+ f.__code__.co_code == _EMPTY_WITH_DOCSTRING_CODE)
+
+
+class ObservationPadding(enum.Enum):
+ INITIAL_VALUE = -1
+ ZERO = 0
+
+
+class EpisodeInitializationError(RuntimeError):
+ """Raised by a `composer.Task` when it fails to initialize an episode."""
+
+
+class _Hook:
+
+ __slots__ = ('entity_hooks', 'extra_hooks')
+
+ def __init__(self):
+ self.entity_hooks = []
+ self.extra_hooks = []
+
+
+class _EnvironmentHooks:
+ """Helper object that scans and memoizes various hooks in a task.
+
+ This object exists to ensure that we do not incur a substantial overhead in
+ calling empty entity hooks in more complicated tasks.
+ """
+
+ __slots__ = (('_task', '_episode_step_count') +
+ tuple('_' + hook_name for hook_name in HOOK_NAMES))
+
+ def __init__(self, task):
+ self._task = task
+ self._episode_step_count = 0
+ for hook_name in HOOK_NAMES:
+ slot_name = '_' + hook_name
+ setattr(self, slot_name, _Hook())
+ self.refresh_entity_hooks()
+
+ def refresh_entity_hooks(self):
+ """Scans and memoizes all non-trivial entity hooks."""
+ for hook_name in HOOK_NAMES:
+ hooks = []
+ for entity in self._task.root_entity.iter_entities():
+ entity_hook = getattr(entity, hook_name)
+ # Ignore any hook that is a no-op to avoid function call overhead.
+ if not _callable_is_trivial(entity_hook):
+ hooks.append(entity_hook)
+ getattr(self, '_' + hook_name).entity_hooks = hooks
+
+ def add_extra_hook(self, hook_name, hook_callable):
+ if hook_name not in HOOK_NAMES:
+ raise ValueError('{!r} is not a valid hook name'.format(hook_name))
+ if not callable(hook_callable):
+ raise ValueError('{!r} is not a callable'.format(hook_callable))
+ getattr(self, '_' + hook_name).extra_hooks.append(hook_callable)
+
+ def initialize_episode_mjcf(self, random_state):
+ self._task.initialize_episode_mjcf(random_state)
+ for entity_hook in self._initialize_episode_mjcf.entity_hooks:
+ entity_hook(random_state)
+ for extra_hook in self._initialize_episode_mjcf.extra_hooks:
+ extra_hook(random_state)
+
+ def after_compile(self, physics, random_state):
+ self._task.after_compile(physics, random_state)
+ for entity_hook in self._after_compile.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_compile.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def initialize_episode(self, physics, random_state):
+ self._episode_step_count = 0
+ self._task.initialize_episode(physics, random_state)
+ for entity_hook in self._initialize_episode.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._initialize_episode.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def before_step(self, physics, action, random_state):
+ self._episode_step_count += 1
+ if self._episode_step_count % _STEPS_LOGGING_INTERVAL == 0:
+ logging.info('The current episode has been running for %d steps.',
+ self._episode_step_count)
+ self._task.before_step(physics, action, random_state)
+ for entity_hook in self._before_step.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._before_step.extra_hooks:
+ extra_hook(physics, action, random_state)
+
+ def before_substep(self, physics, action, random_state):
+ self._task.before_substep(physics, action, random_state)
+ for entity_hook in self._before_substep.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._before_substep.extra_hooks:
+ extra_hook(physics, action, random_state)
+
+ def after_substep(self, physics, random_state):
+ self._task.after_substep(physics, random_state)
+ for entity_hook in self._after_substep.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_substep.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def after_step(self, physics, random_state):
+ self._task.after_step(physics, random_state)
+ for entity_hook in self._after_step.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_step.extra_hooks:
+ extra_hook(physics, random_state)
+
+
+class _CommonEnvironment:
+ """Common components for RL environments."""
+
+ def __init__(self, task, time_limit=float('inf'), random_state=None,
+ n_sub_steps=None,
+ raise_exception_on_physics_error=True,
+ strip_singleton_obs_buffer_dim=False,
+ delayed_observation_padding=ObservationPadding.ZERO,
+ legacy_step: bool = True):
+ """Initializes an instance of `_CommonEnvironment`.
+
+ Args:
+ task: Instance of `composer.base.Task`.
+ time_limit: (optional) A float, the time limit in seconds beyond which an
+ episode is forced to terminate.
+ random_state: Optional, either an int seed or an `np.random.RandomState`
+ object. If None (default), the random number generator will self-seed
+ from a platform-dependent source of entropy.
+ n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
+ agent control step. New code should instead override the
+ `control_timestep` property of the task.
+ raise_exception_on_physics_error: (optional) A boolean, indicating whether
+ `PhysicsError` should be raised as an exception. If `False`, physics
+ errors will result in the current episode being terminated with a
+ warning logged, and a new episode started.
+ strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
+ the array shape of observations with `buffer_size == 1` will not have a
+ leading buffer dimension.
+ delayed_observation_padding: (optional) An `ObservationPadding` enum value
+ specifying the padding behavior of the initial buffers for delayed
+ observables. If `ZERO` then the buffer is initially filled with zeroes.
+ If `INITIAL_VALUE` then the buffer is initially filled with the first
+ observation values.
+ legacy_step: If True, steps the state with up-to-date position and
+ velocity dependent fields. See Page 6 of
+ https://arxiv.org/abs/2006.12983 for more information.
+ """
+ if not isinstance(delayed_observation_padding, ObservationPadding):
+ raise ValueError(
+ f'`delayed_observation_padding` should be an `ObservationPadding` '
+ f'enum value: got {delayed_observation_padding}')
+
+ self._task = task
+ if not isinstance(random_state, np.random.RandomState):
+ self._random_state = np.random.RandomState(random_state)
+ else:
+ self._random_state = random_state
+ self._hooks = _EnvironmentHooks(self._task)
+ self._time_limit = time_limit
+ self._raise_exception_on_physics_error = raise_exception_on_physics_error
+ self._strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim
+ self._delayed_observation_padding = delayed_observation_padding
+ self._legacy_step = legacy_step
+
+ if n_sub_steps is not None:
+ warnings.simplefilter('once', DeprecationWarning)
+ warnings.warn('The `n_sub_steps` argument is deprecated. Please override '
+ 'the `control_timestep` property of the task instead.',
+ DeprecationWarning)
+ self._overridden_n_sub_steps = n_sub_steps
+
+ self._recompile_physics_and_update_observables()
+
+ def add_extra_hook(self, hook_name, hook_callable):
+ self._hooks.add_extra_hook(hook_name, hook_callable)
+
+ def _recompile_physics_and_update_observables(self):
+ """Sets up the environment for the latest MJCF model from the task."""
+ self._physics_proxy = None
+ self._recompile_physics()
+ if isinstance(self._physics, weakref.ProxyType):
+ self._physics_proxy = self._physics
+ else:
+ self._physics_proxy = weakref.proxy(self._physics)
+
+ if self._overridden_n_sub_steps is not None:
+ self._n_sub_steps = self._overridden_n_sub_steps
+ else:
+ self._n_sub_steps = self._task.physics_steps_per_control_step
+
+ self._hooks.refresh_entity_hooks()
+ self._hooks.after_compile(self._physics_proxy, self._random_state)
+ self._observation_updater = self._make_observation_updater()
+ self._observation_updater.reset(self._physics_proxy, self._random_state)
+
+ def _recompile_physics(self):
+ """Creates a new Physics using the latest MJCF model from the task."""
+ physics = getattr(self, '_physics', None)
+ if physics:
+ physics.free()
+ self._physics = mjcf.Physics.from_mjcf_model(
+ self._task.root_entity.mjcf_model)
+ self._physics.legacy_step = self._legacy_step
+
+ def _make_observation_updater(self):
+ pad_with_initial_value = (
+ self._delayed_observation_padding == ObservationPadding.INITIAL_VALUE)
+ return observation.Updater(
+ self._task.observables, self._task.physics_steps_per_control_step,
+ self._strip_singleton_obs_buffer_dim, pad_with_initial_value)
+
+ @property
+ def physics(self):
+ """Returns a `weakref.ProxyType` pointing to the current `mjcf.Physics`.
+
+ Note that the underlying `mjcf.Physics` will be destroyed whenever the MJCF
+ model is recompiled or environment.close() is called. It is therefore unsafe
+ for external objects to hold a reference to `environment.physics`.
+ Attempting to access attributes of a dead `Physics` instance will result in
+ a `ReferenceError`.
+ """
+ return self._physics_proxy
+
+ @property
+ def task(self):
+ return self._task
+
+ @property
+ def random_state(self):
+ return self._random_state
+
+ def control_timestep(self):
+ """Returns the interval between agent actions in seconds."""
+ if self._overridden_n_sub_steps is not None:
+ return self.physics.timestep() * self._overridden_n_sub_steps
+ else:
+ return self.task.control_timestep
+
+
+class Environment(_CommonEnvironment, dm_env.Environment):
+ """Reinforcement learning environment for Composer tasks."""
+
+ def __init__(
+ self,
+ task,
+ time_limit=float('inf'),
+ random_state=None,
+ n_sub_steps=None,
+ raise_exception_on_physics_error=True,
+ strip_singleton_obs_buffer_dim=False,
+ max_reset_attempts=1,
+ recompile_mjcf_every_episode=True,
+ fixed_initial_state=False,
+ delayed_observation_padding=ObservationPadding.ZERO,
+ legacy_step: bool = True,
+ ):
+ """Initializes an instance of `Environment`.
+
+ Args:
+ task: Instance of `composer.base.Task`.
+ time_limit: (optional) A float, the time limit in seconds beyond which an
+ episode is forced to terminate.
+ random_state: (optional) an int seed or `np.random.RandomState` instance.
+ n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
+ agent control step. New code should instead override the
+ `control_timestep` property of the task.
+ raise_exception_on_physics_error: (optional) A boolean, indicating whether
+ `PhysicsError` should be raised as an exception. If `False`, physics
+ errors will result in the current episode being terminated with a
+ warning logged, and a new episode started.
+ strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, the array
+ shape of observations with `buffer_size == 1` will not have a leading
+ buffer dimension.
+ max_reset_attempts: (optional) Maximum number of times to try resetting
+ the environment. If an `EpisodeInitializationError` is raised during
+ this process, an environment reset is reattempted up to this number of
+ times. If this count is exceeded then the most recent exception will be
+ allowed to propagate. Defaults to 1, i.e. no failure is allowed.
+ recompile_mjcf_every_episode: If True, recompiles the MJCF model between
+ episodes. If False, it is recompiled on the first reset only, and later
+ resets skip the `initialize_episode_mjcf` and `after_compile` steps.
+ This allows a speedup if no changes are made to the model.
+ fixed_initial_state: If True, every episode starts from the same state,
+ so an identical sequence of actions leads to an identical final state.
+ If False, the starting state is randomized at the start of every
+ episode.
+ delayed_observation_padding: (optional) An `ObservationPadding` enum value
+ specifying the padding behavior of the initial buffers for delayed
+ observables. If `ZERO` then the buffer is initially filled with zeroes.
+ If `INITIAL_VALUE` then the buffer is initially filled with the first
+ observation values.
+ legacy_step: If True, steps the state with up-to-date position and
+ velocity dependent fields.
+ """
+ super().__init__(
+ task=task,
+ time_limit=time_limit,
+ random_state=random_state,
+ n_sub_steps=n_sub_steps,
+ raise_exception_on_physics_error=raise_exception_on_physics_error,
+ strip_singleton_obs_buffer_dim=strip_singleton_obs_buffer_dim,
+ delayed_observation_padding=delayed_observation_padding,
+ legacy_step=legacy_step)
+ self._max_reset_attempts = max_reset_attempts
+ self._recompile_mjcf_every_episode = recompile_mjcf_every_episode
+ self._mjcf_never_compiled = True
+ self._fixed_initial_state = fixed_initial_state
+ self._fixed_random_state = self._random_state.get_state()
+ self._reset_next_step = True
+
+ def reset(self):
+ failed_attempts = 0
+ while True:
+ try:
+ return self._reset_attempt()
+ except EpisodeInitializationError as e:
+ failed_attempts += 1
+ if failed_attempts < self._max_reset_attempts:
+ logging.error('Error during episode reset: %s', repr(e))
+ else:
+ raise
+
+ def _reset_attempt(self):
+ if self._recompile_mjcf_every_episode or self._mjcf_never_compiled:
+ if self._fixed_initial_state:
+ self._random_state.set_state(self._fixed_random_state)
+ self._hooks.initialize_episode_mjcf(self._random_state)
+ self._recompile_physics_and_update_observables()
+ self._mjcf_never_compiled = False
+
+ if self._fixed_initial_state:
+ self._random_state.set_state(self._fixed_random_state)
+ with self._physics.reset_context():
+ self._hooks.initialize_episode(self._physics_proxy, self._random_state)
+ self._observation_updater.reset(self._physics_proxy, self._random_state)
+ self._reset_next_step = False
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST,
+ reward=None,
+ discount=None,
+ observation=self._observation_updater.get_observation())
+
+ # TODO(b/129061424): Remove this method.
+ def step_spec(self):
+ """DEPRECATED: please use `reward_spec` and `discount_spec` instead."""
+ warnings.warn('`step_spec` is deprecated, please use `reward_spec` and '
+ '`discount_spec` instead.', DeprecationWarning)
+ if (self._task.get_reward_spec() is None or
+ self._task.get_discount_spec() is None):
+ raise NotImplementedError
+ return dm_env.TimeStep(
+ step_type=None,
+ reward=self._task.get_reward_spec(),
+ discount=self._task.get_discount_spec(),
+ observation=self._observation_updater.observation_spec(),
+ )
+
+ def step(self, action):
+ """Updates the environment using the action and returns a `TimeStep`."""
+ if self._reset_next_step:
+ self._reset_next_step = False
+ return self.reset()
+
+ self._hooks.before_step(self._physics_proxy, action, self._random_state)
+ self._observation_updater.prepare_for_next_control_step()
+
+ try:
+ for i in range(self._n_sub_steps):
+ self._substep(action)
+ # The final observation update must happen after all the hooks in
+ # `self._hooks.after_step` are called. Otherwise, if any of these hooks
+ # modify the physics state then we might capture an observation that is
+ # inconsistent with the final physics state.
+ if i < self._n_sub_steps - 1:
+ self._observation_updater.update()
+ physics_is_divergent = False
+ except control.PhysicsError as e:
+ if not self._raise_exception_on_physics_error:
+ logging.warning(e)
+ physics_is_divergent = True
+ else:
+ raise
+
+ self._hooks.after_step(self._physics_proxy, self._random_state)
+ self._observation_updater.update()
+
+ if not physics_is_divergent:
+ reward = self._task.get_reward(self._physics_proxy)
+ discount = self._task.get_discount(self._physics_proxy)
+ terminating = (
+ self._task.should_terminate_episode(self._physics_proxy)
+ or self._physics.time() >= self._time_limit
+ )
+ else:
+ reward = 0.0
+ discount = 0.0
+ terminating = True
+
+ obs = self._observation_updater.get_observation()
+
+ if not terminating:
+ return dm_env.TimeStep(dm_env.StepType.MID, reward, discount, obs)
+ else:
+ self._reset_next_step = True
+ return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, obs)
+
+ def _substep(self, action):
+ self._hooks.before_substep(
+ self._physics_proxy, action, self._random_state)
+ self._physics.step()
+ self._hooks.after_substep(self._physics_proxy, self._random_state)
+
+ def close(self):
+ super().close()
+ self._physics.free()
+ self._physics = None
+
+ def action_spec(self):
+ """Returns the action specification for this environment."""
+ return self._task.action_spec(self._physics_proxy)
+
+ def reward_spec(self):
+ """Describes the reward returned by this environment.
+
+ This will be the output of `self.task.get_reward_spec()` if it is not None,
+ otherwise it will be the default spec returned by
+ `dm_env.Environment.reward_spec()`.
+
+ Returns:
+ A `specs.Array` instance, or a nested dict, list or tuple of
+ `specs.Array`s.
+ """
+ task_reward_spec = self._task.get_reward_spec()
+ if task_reward_spec is not None:
+ return task_reward_spec
+ else:
+ return super().reward_spec()
+
+ def discount_spec(self):
+ """Describes the discount returned by this environment.
+
+ This will be the output of `self.task.get_discount_spec()` if it is not
+ None, otherwise it will be the default spec returned by
+ `dm_env.Environment.discount_spec()`.
+
+ Returns:
+ A `specs.Array` instance, or a nested dict, list or tuple of
+ `specs.Array`s.
+ """
+ task_discount_spec = self._task.get_discount_spec()
+ if task_discount_spec is not None:
+ return task_discount_spec
+ else:
+ return super().discount_spec()
+
+ def observation_spec(self):
+ """Returns the observation specification for this environment.
+
+ Returns:
+ An `OrderedDict` mapping observation name to `specs.Array` containing
+ observation shape and dtype.
+ """
+ return self._observation_updater.observation_spec()
diff --git a/dm_control/composer/environment_hooks_test.py b/dm_control/composer/environment_hooks_test.py
new file mode 100644
index 00000000..f3a225ef
--- /dev/null
+++ b/dm_control/composer/environment_hooks_test.py
@@ -0,0 +1,40 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for Entity and Task hooks in an Environment."""
+
+from absl.testing import absltest
+from dm_control import composer
+from dm_control.composer import hooks_test_utils
+import numpy as np
+
+
+class EnvironmentHooksTest(hooks_test_utils.HooksTestMixin, absltest.TestCase):
+
+ def testEnvironmentHooksScheduling(self):
+ env = composer.Environment(self.task)
+ for hook_name in composer.HOOK_NAMES:
+ env.add_extra_hook(hook_name, getattr(self.extra_hooks, hook_name))
+ for _ in range(self.num_episodes):
+ with self.track_episode():
+ env.reset()
+ for _ in range(self.steps_per_episode):
+ env.step([0.1, 0.2, 0.3, 0.4])
+ np.testing.assert_array_equal(env.physics.data.ctrl,
+ [0.1, 0.2, 0.3, 0.4])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/environment_test.py b/dm_control/composer/environment_test.py
new file mode 100644
index 00000000..b1900d1f
--- /dev/null
+++ b/dm_control/composer/environment_test.py
@@ -0,0 +1,162 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.composer.environment."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+import dm_env
+import mock
+import numpy as np
+
+
+class DummyTask(composer.NullTask):
+
+ def __init__(self):
+ null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
+ super().__init__(null_entity)
+
+ @property
+ def task_observables(self):
+ time = observable.Generic(lambda physics: physics.time())
+ time.enabled = True
+ return {'time': time}
+
+
+class DummyTaskWithResetFailures(DummyTask):
+
+ def __init__(self, num_reset_failures):
+ super().__init__()
+ self.num_reset_failures = num_reset_failures
+ self.reset_counter = 0
+
+ def initialize_episode_mjcf(self, random_state):
+ self.reset_counter += 1
+
+ def initialize_episode(self, physics, random_state):
+ if self.reset_counter <= self.num_reset_failures:
+ raise composer.EpisodeInitializationError()
+
+
+class DummyTaskWithRandomObservation(composer.NullTask):
+
+ def __init__(self):
+ null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
+ super().__init__(null_entity)
+
+ self._observation = [0.0] * 1000
+
+ def initialize_episode(self, physics, random_state):
+ del physics
+ self._observation = random_state.randint(1000, size=1000)
+
+ @property
+ def task_observables(self):
+ random_int = observable.Generic(lambda physics: self._observation)
+ random_int.enabled = True
+ return {'random_int': random_int}
+
+
+class EnvironmentTest(parameterized.TestCase):
+
+ def test_failed_resets(self):
+ total_reset_failures = 5
+ env_reset_attempts = 2
+ task = DummyTaskWithResetFailures(num_reset_failures=total_reset_failures)
+ env = composer.Environment(task, max_reset_attempts=env_reset_attempts)
+ for _ in range(total_reset_failures // env_reset_attempts):
+ with self.assertRaises(composer.EpisodeInitializationError):
+ env.reset()
+ env.reset() # should not raise an exception
+ self.assertEqual(task.reset_counter, total_reset_failures + 1)
+
+ @parameterized.parameters(
+ dict(name='reward_spec', defined_in_task=True),
+ dict(name='reward_spec', defined_in_task=False),
+ dict(name='discount_spec', defined_in_task=True),
+ dict(name='discount_spec', defined_in_task=False))
+ def test_get_spec(self, name, defined_in_task):
+ task = DummyTask()
+ env = composer.Environment(task)
+ with mock.patch.object(task, 'get_' + name) as mock_task_get_spec:
+ if defined_in_task:
+ expected_spec = mock.Mock()
+ mock_task_get_spec.return_value = expected_spec
+ else:
+ expected_spec = getattr(dm_env.Environment, name)(env)
+ mock_task_get_spec.return_value = None
+ spec = getattr(env, name)()
+ mock_task_get_spec.assert_called_once_with()
+ self.assertSameStructure(spec, expected_spec)
+
+ def test_can_provide_observation(self):
+ task = DummyTask()
+ env = composer.Environment(task)
+ obs = env.reset().observation
+ self.assertLen(obs, 1)
+ np.testing.assert_array_equal(obs['time'], env.physics.time())
+ for _ in range(20):
+ obs = env.step([]).observation
+ self.assertLen(obs, 1)
+ np.testing.assert_array_equal(obs['time'], env.physics.time())
+
+ def test_dont_compile_mjcf_between_episodes(self):
+ class AfterCompileHook(object):
+
+ def __init__(self):
+ self.after_compile_call_count = 0
+
+ def __call__(self, physics, random_state):
+ del physics, random_state
+ self.after_compile_call_count += 1
+
+ after_compile_hook = AfterCompileHook()
+ task = DummyTask()
+ env = composer.Environment(task, recompile_mjcf_every_episode=False)
+ env.add_extra_hook('after_compile', after_compile_hook)
+ env.reset()
+ self.assertEqual(after_compile_hook.after_compile_call_count, 1)
+ for _ in range(4):
+ env.reset()
+ env.step([])
+
+ # Check the hook is not called again after the first reset.
+ self.assertEqual(after_compile_hook.after_compile_call_count, 1)
+
+ def test_fixed_initial_state(self):
+ task = DummyTaskWithRandomObservation()
+ fixed_env = composer.Environment(task, fixed_initial_state=True)
+ non_fixed_env = composer.Environment(task, fixed_initial_state=False)
+ fixed_obs = fixed_env.reset().observation['random_int']
+ non_fixed_obs = non_fixed_env.reset().observation['random_int']
+ for _ in range(3):
+ np.testing.assert_array_equal(
+ fixed_env.reset().observation['random_int'], fixed_obs
+ )
+ self.assertTrue(
+ np.any(
+ np.not_equal(
+ np.asarray(non_fixed_obs),
+ np.asarray(non_fixed_env.reset().observation['random_int']),
+ )
+ )
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/hooks_test_utils.py b/dm_control/composer/hooks_test_utils.py
new file mode 100644
index 00000000..d32763eb
--- /dev/null
+++ b/dm_control/composer/hooks_test_utils.py
@@ -0,0 +1,323 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities for testing environment hooks."""
+
+import collections
+import contextlib
+import inspect
+
+from dm_control import composer
+from dm_control import mjcf
+
+
+def add_bodies_and_actuators(mjcf_model, num_actuators):
+ if num_actuators % 2:
+ raise ValueError('num_actuators is not a multiple of 2')
+ for _ in range(num_actuators // 2):
+ body = mjcf_model.worldbody.add('body')
+ body.add('inertial', pos=[0, 0, 0], mass=1, diaginertia=[1, 1, 1])
+ joint_x = body.add('joint', axis=[1, 0, 0])
+ mjcf_model.actuator.add('position', joint=joint_x)
+ joint_y = body.add('joint', axis=[0, 1, 0])
+ mjcf_model.actuator.add('position', joint=joint_y)
+
+
+class HooksTracker:
+ """Helper class for tracking call order of callbacks."""
+
+ def __init__(self, test_case, physics_timestep, control_timestep,
+ *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.tracked = False
+ self._test_case = test_case
+ self._call_count = collections.defaultdict(lambda: 0)
+ self._physics_timestep = physics_timestep
+ self._physics_steps_per_control_step = (
+ int(round(control_timestep / physics_timestep)))
+
+ mro = inspect.getmro(type(self))
+ self._has_super = mro[mro.index(HooksTracker) + 1] != object
+
+ def assertEqual(self, actual, expected, msg=''):
+ msg = '{}: {}: {!r} != {!r}'.format(type(self), msg, actual, expected)
+ self._test_case.assertEqual(actual, expected, msg)
+
+ def assertHooksNotCalled(self, *hook_names):
+ for hook_name in hook_names:
+ self.assertEqual(
+ self._call_count[hook_name], 0,
+ 'assertHooksNotCalled: hook_name = {!r}'.format(hook_name))
+
+ def assertHooksCalledOnce(self, *hook_names):
+ for hook_name in hook_names:
+ self.assertEqual(
+ self._call_count[hook_name], 1,
+ 'assertHooksCalledOnce: hook_name = {!r}'.format(hook_name))
+
+ def assertCompleteEpisode(self, control_steps):
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+ physics_steps = control_steps * self._physics_steps_per_control_step
+ self.assertEqual(self._call_count['before_step'], control_steps)
+ self.assertEqual(self._call_count['before_substep'], physics_steps)
+ self.assertEqual(self._call_count['after_substep'], physics_steps)
+ self.assertEqual(self._call_count['after_step'], control_steps)
+
+ def assertPhysicsStepCountEqual(self, physics, expected_count):
+ actual_count = int(round(physics.time() / self._physics_timestep))
+ self.assertEqual(actual_count, expected_count)
+
+ def reset_call_counts(self):
+ self._call_count = collections.defaultdict(lambda: 0)
+
+ def initialize_episode_mjcf(self, random_state):
+ """Implements `initialize_episode_mjcf` Composer callback."""
+ if self._has_super:
+ super().initialize_episode_mjcf(random_state)
+ if not self.tracked:
+ return
+ self.assertHooksNotCalled('after_compile',
+ 'initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ self._call_count['initialize_episode_mjcf'] += 1
+
+ def after_compile(self, physics, random_state):
+ """Implements `after_compile` Composer callback."""
+ if self._has_super:
+ super().after_compile(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf')
+ self.assertHooksNotCalled('initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+ self._call_count['after_compile'] += 1
+
+ def initialize_episode(self, physics, random_state):
+ """Implements `initialize_episode` Composer callback."""
+ if self._has_super:
+ super().initialize_episode(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile')
+ self.assertHooksNotCalled('before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+ self._call_count['initialize_episode'] += 1
+
+ def before_step(self, physics, *args):
+ """Implements `before_step` Composer callback."""
+ if self._has_super:
+ super().before_step(physics, *args)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # `before_step` is only called in between complete control steps.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'])
+
+ # Complete control steps imply complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+
+ self._call_count['before_step'] += 1
+
+ def before_substep(self, physics, *args):
+ """Implements `before_substep` Composer callback."""
+ if self._has_super:
+ super().before_substep(physics, *args)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # `before_substep` is only called in between complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ self._call_count['before_substep'] += 1
+
+ def after_substep(self, physics, random_state):
+ """Implements `after_substep` Composer callback."""
+ if self._has_super:
+ super().after_substep(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # We are inside a partial physics step, so `after_substep` should be behind.
+ self.assertEqual(self._call_count['after_substep'],
+ self._call_count['before_substep'] - 1)
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ self._call_count['after_substep'] += 1
+
+ def after_step(self, physics, random_state):
+ """Implements `after_step` Composer callback."""
+ if self._has_super:
+ super().after_step(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # `after_step` is only called in between complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ # Check that the number of physics steps is consistent with control steps.
+ self.assertEqual(
+ self._call_count['before_substep'],
+ self._call_count['before_step'] * self._physics_steps_per_control_step)
+
+ self._call_count['after_step'] += 1
+
+
+class TrackedEntity(HooksTracker, composer.Entity):
+ """A `composer.Entity` that tracks call order of callbacks."""
+
+ def _build(self, name):
+ self._mjcf_root = mjcf.RootElement(model=name)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def name(self):
+ return self._mjcf_root.model
+
+
+class TrackedTask(HooksTracker, composer.NullTask):
+ """A `composer.Task` that tracks call order of callbacks."""
+
+ def __init__(self, physics_timestep, control_timestep, *args, **kwargs):
+ super().__init__(
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep,
+ *args,
+ **kwargs)
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+ add_bodies_and_actuators(self.root_entity.mjcf_model, num_actuators=4)
+
+
+class HooksTestMixin:
+ """A mixin for an `absltest.TestCase` to track call order of callbacks."""
+
+ def setUp(self):
+ """Sets up the test case."""
+ super().setUp()
+
+ self.num_episodes = 5
+ self.steps_per_episode = 100
+
+ self.control_timestep = 0.05
+ self.physics_timestep = 0.002
+
+ self.extra_hooks = HooksTracker(physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self)
+
+ self.entities = []
+ for i in range(9):
+ self.entities.append(TrackedEntity(name='entity_{}'.format(i),
+ physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self))
+
+ # Make the following entity hierarchy (parent -> children):
+ # 0 -> 1, 2, 3
+ # 1 -> 4, 5
+ # 2 -> 6, 7
+ # 4 -> 8
+
+ self.entities[4].attach(self.entities[8])
+ self.entities[1].attach(self.entities[4])
+ self.entities[1].attach(self.entities[5])
+ self.entities[0].attach(self.entities[1])
+
+ self.entities[2].attach(self.entities[6])
+ self.entities[2].attach(self.entities[7])
+ self.entities[0].attach(self.entities[2])
+
+ self.entities[0].attach(self.entities[3])
+
+ self.task = TrackedTask(root_entity=self.entities[0],
+ physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self)
+
+ @contextlib.contextmanager
+ def track_episode(self):
+ tracked_objects = [self.task, self.extra_hooks] + self.entities
+ for obj in tracked_objects:
+ obj.reset_call_counts()
+ obj.tracked = True
+ yield
+ for obj in tracked_objects:
+ obj.assertCompleteEpisode(self.steps_per_episode)
+ obj.tracked = False
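The call-count invariants checked by `assertCompleteEpisode` can be illustrated without MuJoCo. The following is a minimal sketch, not Composer's actual environment loop: `run_episode` is a hypothetical helper that only mimics the callback order implied by the assertions in `HooksTracker`, and it counts calls instead of invoking real hooks.

```python
import collections


def run_episode(control_steps, physics_timestep, control_timestep):
  """Mimics the hook order of one episode and returns per-hook call counts."""
  counts = collections.Counter()
  # Round before truncating so that floating-point error in the ratio
  # cannot drop a substep.
  substeps = int(round(control_timestep / physics_timestep))
  for name in ('initialize_episode_mjcf', 'after_compile',
               'initialize_episode'):
    counts[name] += 1
  for _ in range(control_steps):
    counts['before_step'] += 1
    for _ in range(substeps):
      counts['before_substep'] += 1
      # A real environment would advance the physics by one timestep here.
      counts['after_substep'] += 1
    counts['after_step'] += 1
  return counts


counts = run_episode(control_steps=100, physics_timestep=0.002,
                     control_timestep=0.05)
print(counts['before_step'], counts['before_substep'])
```

With the timesteps used in `HooksTestMixin.setUp`, each control step spans 25 physics substeps, which is the ratio `assertCompleteEpisode` relies on.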
diff --git a/dm_control/composer/initializer.py b/dm_control/composer/initializer.py
new file mode 100644
index 00000000..9db64a97
--- /dev/null
+++ b/dm_control/composer/initializer.py
@@ -0,0 +1,26 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract initializer."""
+
+import abc
+
+
+class Initializer(metaclass=abc.ABCMeta):
+ """The abstract base class for an initializer."""
+
+ @abc.abstractmethod
+ def __call__(self, physics, random_state):
+ raise NotImplementedError
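A concrete initializer only needs to implement `__call__(physics, random_state)`. The sketch below re-declares the abstract base locally so that it runs without `dm_control`. `FixedHeightInitializer` and the dictionary standing in for `physics` are hypothetical and exist only for illustration.

```python
import abc
import random


class Initializer(metaclass=abc.ABCMeta):
  """Local copy of the abstract base, so the sketch is self-contained."""

  @abc.abstractmethod
  def __call__(self, physics, random_state):
    raise NotImplementedError


class FixedHeightInitializer(Initializer):
  """Hypothetical initializer that samples a height within fixed bounds."""

  def __init__(self, low, high):
    self._low = low
    self._high = high

  def __call__(self, physics, random_state):
    # `physics` is a plain dict in this sketch. A real initializer would
    # receive an `mjcf.Physics` instance and a `np.random.RandomState`.
    physics['height'] = random_state.uniform(self._low, self._high)


fake_physics = {}
FixedHeightInitializer(low=0.1, high=0.2)(fake_physics, random.Random(0))
```

Because `__call__` is abstract, instantiating `Initializer` directly raises a `TypeError`; only subclasses that override it can be constructed.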
diff --git a/dm_control/composer/initializers/__init__.py b/dm_control/composer/initializers/__init__.py
new file mode 100644
index 00000000..faeac98c
--- /dev/null
+++ b/dm_control/composer/initializers/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tools for initializing the states of Composer environments."""
+
+from dm_control.composer.initializers.prop_initializer import PropPlacer
+from dm_control.composer.initializers.tcp_initializer import ToolCenterPointInitializer
diff --git a/dm_control/composer/initializers/prop_initializer.py b/dm_control/composer/initializers/prop_initializer.py
new file mode 100644
index 00000000..b809b31a
--- /dev/null
+++ b/dm_control/composer/initializers/prop_initializer.py
@@ -0,0 +1,285 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An initializer that places props at various poses."""
+
+from absl import logging
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import variation
+from dm_control.composer.initializers import utils
+from dm_control.composer.variation import rotations
+from dm_control.rl import control
+import numpy as np
+
+
+# Absolute velocity threshold for a prop joint to be considered settled.
+_SETTLE_QVEL_TOL = 1e-3
+# Absolute acceleration threshold for a prop joint to be considered settled.
+_SETTLE_QACC_TOL = 1e-2
+
+_REJECTION_SAMPLING_FAILED = '\n'.join([
+ 'Failed to find a non-colliding pose for prop {model_name!r} within ' # pylint: disable=implicit-str-concat
+ '{max_attempts} attempts.',
+ 'You may be able to avoid this error by:',
+ '1. Sampling from a broader distribution over positions and/or quaternions',
+ '2. Increasing `max_attempts_per_prop`',
+ '3. Disabling collision detection by setting `ignore_collisions=True`'])
+
+
+_SETTLING_PHYSICS_FAILED = '\n'.join([
+ 'Failed to settle physics after {max_attempts} attempts of ' # pylint: disable=implicit-str-concat
+ '{max_time} seconds.',
+ 'Last residual velocity={max_qvel} and acceleration={max_qacc}.',
+ 'This suggests your dynamics are unstable. Consider:',
+ '\t1. Increasing `max_settle_physics_attempts`',
+ '\t2. Increasing `max_settle_physics_time`',
+ '\t3. Tuning your contact parameters or initial pose distributions.'])
+
+
+class PropPlacer(composer.Initializer):
+ """An initializer that places props at various positions and orientations."""
+
+ def __init__(self,
+ props,
+ position,
+ quaternion=rotations.IDENTITY_QUATERNION,
+ ignore_collisions=False,
+ max_qvel_tol=_SETTLE_QVEL_TOL,
+ max_qacc_tol=_SETTLE_QACC_TOL,
+ max_attempts_per_prop=20,
+ settle_physics=False,
+ min_settle_physics_time=0.,
+ max_settle_physics_time=2.,
+ max_settle_physics_attempts=1,
+ raise_exception_on_settle_failure=False):
+ """Initializes this PropPlacer.
+
+ Args:
+ props: A sequence of `composer.Entity` instances representing props.
+ position: A single fixed Cartesian position, or a `composer.Variation`
+ object that generates Cartesian positions. If a fixed sequence of
+ positions for multiple props is desired, use
+ `variation.deterministic.Sequence`.
+ quaternion: (optional) A single fixed unit quaternion, or a
+ `Variation` object that generates unit quaternions. If a fixed
+ sequence of quaternions for multiple props is desired, use
+ `variation.deterministic.Sequence`.
+ ignore_collisions: (optional) If True, ignore collisions between props,
+ i.e. do not run rejection sampling.
+ max_qvel_tol: Maximum post-initialization joint velocity for props. If
+ `settle_physics=True`, the simulation will be run until all prop joint
+ velocities are less than this threshold.
+ max_qacc_tol: Maximum post-initialization joint acceleration for props. If
+ `settle_physics=True`, the simulation will be run until all prop joint
+ accelerations are less than this threshold.
+ max_attempts_per_prop: The maximum number of rejection sampling attempts
+ per prop. If a non-colliding pose cannot be found before this limit is
+ reached, an `EpisodeInitializationError` will be raised.
+ settle_physics: (optional) If True, the physics simulation will be
+ advanced for a few steps to allow the prop positions to settle.
+ min_settle_physics_time: (optional) When `settle_physics` is True, lower
+ bound on time (in seconds) the physics simulation is advanced.
+ max_settle_physics_time: (optional) When `settle_physics` is True, upper
+ bound on time (in seconds) the physics simulation is advanced.
+ max_settle_physics_attempts: (optional) When `settle_physics` is True, the
+ number of attempts at sampling overall scene pose and settling.
+ raise_exception_on_settle_failure: If True, raises an exception if
+ settling physics is unsuccessful.
+ """
+ super().__init__()
+ self._props = props
+ self._prop_joints = []
+ for prop in props:
+ freejoint = mjcf.get_frame_freejoint(prop.mjcf_model)
+ if freejoint is not None:
+ self._prop_joints.append(freejoint)
+ self._prop_joints.extend(prop.mjcf_model.find_all('joint'))
+ self._position = position
+ self._quaternion = quaternion
+ self._ignore_collisions = ignore_collisions
+ self._max_attempts_per_prop = max_attempts_per_prop
+ self._settle_physics = settle_physics
+ self._max_qvel_tol = max_qvel_tol
+ self._max_qacc_tol = max_qacc_tol
+ self._min_settle_physics_time = min_settle_physics_time
+ self._max_settle_physics_time = max_settle_physics_time
+ self._max_settle_physics_attempts = max_settle_physics_attempts
+ self._raise_exception_on_settle_failure = raise_exception_on_settle_failure
+
+ if max_settle_physics_attempts < 1:
+ raise ValueError('max_settle_physics_attempts should be greater than '
+ 'zero to have any effect, but is '
+ f'{max_settle_physics_attempts}')
+
+ def _has_collisions_with_prop(self, physics, prop):
+ prop_geom_ids = physics.bind(prop.mjcf_model.find_all('geom')).element_id
+ contacts = physics.data.contact
+ for contact in contacts:
+ # Ignore contacts with positive distances (i.e. not actually touching).
+ if contact.dist <= 0 and (contact.geom1 in prop_geom_ids or
+ contact.geom2 in prop_geom_ids):
+ return True
+
+ def _disable_and_cache_contact_parameters(self, physics, props):
+ cached_contact_params = {}
+ for prop in props:
+ geoms = prop.mjcf_model.find_all('geom')
+ param_list = []
+ for geom in geoms:
+ bound_geom = physics.bind(geom)
+ param_list.append((bound_geom.contype, bound_geom.conaffinity))
+ bound_geom.contype = 0
+ bound_geom.conaffinity = 0
+ cached_contact_params[prop] = param_list
+ return cached_contact_params
+
+ def _restore_contact_parameters(self, physics, prop, cached_contact_params):
+ geoms = prop.mjcf_model.find_all('geom')
+ param_list = cached_contact_params[prop]
+ for i, geom in enumerate(geoms):
+ contype, conaffinity = param_list[i]
+ bound_geom = physics.bind(geom)
+ bound_geom.contype = contype
+ bound_geom.conaffinity = conaffinity
+
+ def __call__(self, physics, random_state, ignore_contacts_with_entities=None):
+ """Sets initial prop poses.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ random_state: a `np.random.RandomState` instance.
+ ignore_contacts_with_entities: a list of `composer.Entity` instances
+ to ignore when detecting collisions. This can be used to ignore props
+ that are not being placed by this initializer, but are known to be
+ colliding in the current state of the simulation (for example if they
+ are going to be placed by a different initializer that will be called
+ subsequently).
+
+ Raises:
+ EpisodeInitializationError: If `ignore_collisions == False` and a
+ non-colliding prop pose could not be found within
+ `max_attempts_per_prop`.
+ """
+ if ignore_contacts_with_entities is None:
+ ignore_contacts_with_entities = []
+ # Temporarily disable contacts for all geoms that belong to props which
+ # haven't yet been placed in order to free up space in the contact buffer.
+ cached_contact_params = self._disable_and_cache_contact_parameters(
+ physics, self._props + ignore_contacts_with_entities)
+
+ try:
+ physics.forward()
+ except control.PhysicsError as cause:
+ effect = control.PhysicsError(
+ 'Despite disabling contact for all props in this initializer, '
+ '`physics.forward()` resulted in a `PhysicsError`')
+ raise effect from cause
+
+ def place_props():
+ for prop in self._props:
+ # Restore the original contact parameters for all geoms in the prop
+ # we're about to place, so that we can detect if the new pose results in
+ # collisions.
+ self._restore_contact_parameters(physics, prop, cached_contact_params)
+
+ success = False
+ initial_position, initial_quaternion = prop.get_pose(physics)
+ next_position, next_quaternion = initial_position, initial_quaternion
+ for _ in range(self._max_attempts_per_prop):
+ next_position = variation.evaluate(self._position,
+ initial_value=initial_position,
+ current_value=next_position,
+ random_state=random_state)
+ next_quaternion = variation.evaluate(self._quaternion,
+ initial_value=initial_quaternion,
+ current_value=next_quaternion,
+ random_state=random_state)
+ prop.set_pose(physics, next_position, next_quaternion)
+ try:
+ # If this pose results in collisions then there's a chance we'll
+ # encounter a `PhysicsError` here due to a full contact buffer,
+ # in which case reject this pose and sample another.
+ physics.forward()
+ except control.PhysicsError:
+ continue
+
+ if (self._ignore_collisions
+ or not self._has_collisions_with_prop(physics, prop)):
+ success = True
+ break
+
+ if not success:
+ raise composer.EpisodeInitializationError(
+ _REJECTION_SAMPLING_FAILED.format(
+ model_name=prop.mjcf_model.model,
+ max_attempts=self._max_attempts_per_prop,
+ )
+ )
+
+ for prop in ignore_contacts_with_entities:
+ self._restore_contact_parameters(physics, prop, cached_contact_params)
+
+ # Place the props and settle the physics. If settling was requested and
+ # it fails, re-place the props.
+ def place_and_settle():
+ for _ in range(self._max_settle_physics_attempts):
+ place_props()
+
+ # Step physics and check prop states.
+ original_time = physics.data.time
+ try:
+ props_isolator = utils.JointStaticIsolator(physics, self._prop_joints)
+ prop_joints_mj = physics.bind(self._prop_joints)
+ while (
+ physics.data.time - original_time < self._max_settle_physics_time
+ ):
+ with props_isolator:
+ physics.step()
+ max_qvel = np.max(np.abs(prop_joints_mj.qvel))
+ max_qacc = np.max(np.abs(prop_joints_mj.qacc))
+ if (max_qvel < self._max_qvel_tol) and (
+ max_qacc < self._max_qacc_tol) and (
+ physics.data.time - original_time
+ ) > self._min_settle_physics_time:
+ return True
+ finally:
+ physics.data.time = original_time
+
+ if self._raise_exception_on_settle_failure:
+ raise composer.EpisodeInitializationError(
+ _SETTLING_PHYSICS_FAILED.format(
+ max_attempts=self._max_settle_physics_attempts,
+ max_time=self._max_settle_physics_time,
+ max_qvel=max_qvel,
+ max_qacc=max_qacc,
+ ))
+ else:
+ log_str = _SETTLING_PHYSICS_FAILED.format(
+ max_attempts='%s',
+ max_time='%s',
+ max_qvel='%s',
+ max_qacc='%s',
+ )
+ logging.warning(log_str, self._max_settle_physics_attempts,
+ self._max_settle_physics_time, max_qvel, max_qacc)
+
+ return False
+
+ if self._settle_physics:
+ place_and_settle()
+ else:
+ place_props()
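The rejection-sampling loop in `place_props` can be sketched in isolation. This is a simplified illustration under stated assumptions, not the `PropPlacer` implementation: props are reduced to equal-radius spheres, the collision test is a plain distance check rather than MuJoCo contacts, and `place_spheres` is a hypothetical helper.

```python
import math
import random


def place_spheres(num_spheres, radius, low, high, max_attempts_per_prop, rng):
  """Places spheres one at a time, resampling any pose that overlaps."""
  placed = []
  for index in range(num_spheres):
    for _ in range(max_attempts_per_prop):
      candidate = tuple(rng.uniform(low, high) for _ in range(3))
      # Accept the pose only if it clears every sphere placed so far.
      if all(math.dist(candidate, other) > 2 * radius for other in placed):
        placed.append(candidate)
        break
    else:
      # Mirrors the error path: give up once the attempt budget is spent.
      raise RuntimeError(
          'Failed to find a non-colliding pose for sphere {} within {} '
          'attempts.'.format(index, max_attempts_per_prop))
  return placed


positions = place_spheres(num_spheres=4, radius=0.01, low=-0.5, high=0.5,
                          max_attempts_per_prop=20, rng=random.Random(0))
```

As in `PropPlacer`, props are placed in order and each one is checked only against props that already have a pose, so a broader sampling region or a larger attempt budget makes failure less likely.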
diff --git a/dm_control/composer/initializers/prop_initializer_test.py b/dm_control/composer/initializers/prop_initializer_test.py
new file mode 100644
index 00000000..dc27d5ca
--- /dev/null
+++ b/dm_control/composer/initializers/prop_initializer_test.py
@@ -0,0 +1,238 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.initializers import prop_initializer
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+import numpy as np
+
+
+class _SequentialChoice(distributions.Distribution):
+ """Helper class to return samples in order for deterministic testing."""
+ __slots__ = ()
+
+ def __init__(self, choices, single_sample=False):
+ super().__init__(choices, single_sample=single_sample)
+ self._idx = 0
+
+ def _callable(self, random_state):
+ def next_item(*args, **kwargs):
+ del args, kwargs # Unused.
+ result = self._args[0][self._idx]
+ self._idx = (self._idx + 1) % len(self._args[0])
+ return result
+
+ return next_item
+
+
+def _make_spheres(num_spheres, radius, nconmax):
+ spheres = []
+ arena = composer.Arena()
+ arena.mjcf_model.worldbody.add('geom', type='plane', size=[1, 1, 0.1],
+ pos=[0., 0., -2 * radius], name='ground')
+ for i in range(num_spheres):
+ sphere = props.Primitive(
+ geom_type='sphere', size=[radius], name='sphere_{}'.format(i))
+ arena.add_free_entity(sphere)
+ spheres.append(sphere)
+ arena.mjcf_model.size.nconmax = nconmax
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ return physics, spheres
+
+
+class PropPlacerTest(parameterized.TestCase):
+ """Tests for PropPlacer."""
+
+ def assertNoContactsInvolvingEntities(self, physics, entities):
+ all_colliding_geoms = set()
+ for contact in physics.data.contact:
+ all_colliding_geoms.add(contact.geom1)
+ all_colliding_geoms.add(contact.geom2)
+ for entity in entities:
+ entity_geoms = physics.bind(entity.mjcf_model.find_all('geom')).element_id
+ colliding_entity_geoms = all_colliding_geoms.intersection(entity_geoms)
+ if colliding_entity_geoms:
+ names = ', '.join(
+ physics.model.id2name(i, 'geom') for i in colliding_entity_geoms)
+ self.fail('Entity {} has colliding geoms: {}'
+ .format(entity.mjcf_model.model, names))
+
+ def assertPositionsWithinBounds(self, physics, entities, lower, upper):
+ for entity in entities:
+ position, _ = entity.get_pose(physics)
+ if np.any(position < lower) or np.any(position > upper):
+ self.fail('Entity {} is out of bounds: position={}, bounds={}'
+ .format(entity.mjcf_model.model, position, (lower, upper)))
+
+ def test_sample_non_colliding_positions(self):
+ halfwidth = 0.05
+ radius = halfwidth / 4.
+ offset = np.array([0, 0, halfwidth + radius*1.1])
+ lower = -np.full(3, halfwidth) + offset
+ upper = np.full(3, halfwidth) + offset
+ position_variation = distributions.Uniform(lower, upper)
+ physics, spheres = _make_spheres(num_spheres=8, radius=radius, nconmax=1000)
+ prop_placer = prop_initializer.PropPlacer(
+ props=spheres,
+ position=position_variation,
+ ignore_collisions=False,
+ settle_physics=False)
+ prop_placer(physics, random_state=np.random.RandomState(0))
+ self.assertNoContactsInvolvingEntities(physics, spheres)
+ self.assertPositionsWithinBounds(physics, spheres, lower, upper)
+
+ def test_rejection_sampling_failure(self):
+ max_attempts_per_prop = 2
+ fixed_position = (0, 0, 0.1) # Guaranteed to always have collisions.
+ physics, spheres = _make_spheres(num_spheres=2, radius=0.01, nconmax=1000)
+ prop_placer = prop_initializer.PropPlacer(
+ props=spheres,
+ position=fixed_position,
+ ignore_collisions=False,
+ max_attempts_per_prop=max_attempts_per_prop)
+ expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
+ model_name=spheres[1].mjcf_model.model, # Props are placed in order.
+ max_attempts=max_attempts_per_prop)
+ with self.assertRaisesWithLiteralMatch(
+ composer.EpisodeInitializationError, expected_message
+ ):
+ prop_placer(physics, random_state=np.random.RandomState(0))
+
+ def test_ignore_contacts_with_entities(self):
+ physics, spheres = _make_spheres(num_spheres=2, radius=0.01, nconmax=1000)
+
+ # Target position of both spheres (non-colliding).
+ fixed_positions = [(0, 0, 0.1), (0, 0.1, 0.1)]
+
+ # Placer that initializes both spheres to (0, 0, 0.1), ignoring contacts.
+ prop_placer_init = prop_initializer.PropPlacer(
+ props=spheres,
+ position=fixed_positions[0],
+ ignore_collisions=True,
+ max_attempts_per_prop=1)
+
+ # Sequence of placers that will move the spheres to their target positions.
+ prop_placer_seq = []
+ for prop, target_position in zip(spheres, fixed_positions):
+ placer = prop_initializer.PropPlacer(
+ props=[prop],
+ position=target_position,
+ ignore_collisions=False,
+ max_attempts_per_prop=1)
+ prop_placer_seq.append(placer)
+
+ # We expect the first placer in the sequence to fail without
+ # `ignore_contacts_with_entities` because the second sphere is already at
+ # the same location.
+ prop_placer_init(physics, random_state=np.random.RandomState(0))
+ expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
+ model_name=spheres[0].mjcf_model.model, max_attempts=1)
+ with self.assertRaisesWithLiteralMatch(
+ composer.EpisodeInitializationError, expected_message
+ ):
+ prop_placer_seq[0](physics, random_state=np.random.RandomState(0))
+
+ # Placing the first sphere should succeed if we ignore contacts involving
+ # the second sphere.
+ prop_placer_init(physics, random_state=np.random.RandomState(0))
+ prop_placer_seq[0](physics, random_state=np.random.RandomState(0),
+ ignore_contacts_with_entities=[spheres[1]])
+ # Now place the second sphere with all collisions active.
+ prop_placer_seq[1](physics, random_state=np.random.RandomState(0),
+ ignore_contacts_with_entities=None)
+ self.assertNoContactsInvolvingEntities(physics, spheres)
+
+ @parameterized.parameters([False, True])
+ def test_settle_physics(self, settle_physics):
+ radius = 0.1
+ physics, spheres = _make_spheres(num_spheres=2, radius=radius, nconmax=1)
+ physics_start_time = 1337.0
+ physics.data.time = physics_start_time
+
+ # Only place the first sphere.
+ prop_placer = prop_initializer.PropPlacer(
+ props=spheres[:1],
+ position=np.array([2.01 * radius, 0., 0.]),
+ settle_physics=settle_physics)
+ prop_placer(physics, random_state=np.random.RandomState(0))
+
+ first_position, first_quaternion = spheres[0].get_pose(physics)
+ del first_quaternion # Unused.
+
+ # If we allowed the physics to settle then the first sphere should be
+ # resting on the ground, otherwise it should be at the target height.
+ expected_first_z_pos = -radius if settle_physics else 0.
+ self.assertAlmostEqual(first_position[2], expected_first_z_pos, places=3)
+
+ second_position, second_quaternion = spheres[1].get_pose(physics)
+ del second_quaternion # Unused.
+
+ # The sphere that we were not placing should not have moved.
+ self.assertEqual(second_position[2], 0.0)
+ self.assertEqual(
+ physics.data.time, physics_start_time, 'Physics time should be reset.'
+ )
+
+ @parameterized.parameters([0, 1, 2, 3])
+ def test_settle_physics_multiple_attempts(self, max_settle_physics_attempts):
+ # Tests the multiple-reset mechanism for `settle_physics`.
+ # Rather than testing the mechanic itself, which is tested above, we instead
+ # test that the mechanism correctly makes several attempts when it fails
+ # to settle. We force it to fail by making the settling time short, and
+ # test that the position is repeatedly sampled using a deterministic
+ # sequential pose distribution.
+
+ radius = 0.1
+ physics, spheres = _make_spheres(num_spheres=1, radius=radius, nconmax=1)
+
+ # Generate sequence of positions that will be sampled in order.
+ positions = [
+ np.array([2.01 * radius, 1., 0.]),
+ np.array([2.01 * radius, 2., 0.]),
+ np.array([2.01 * radius, 3., 0.]),
+ ]
+ positions_dist = _SequentialChoice(positions)
+
+ def build_placer():
+ return prop_initializer.PropPlacer(
+ props=spheres[:1],
+ position=positions_dist,
+ settle_physics=True,
+ max_settle_physics_time=1e-6, # To ensure that settling FAILS.
+ max_settle_physics_attempts=max_settle_physics_attempts)
+
+ if max_settle_physics_attempts == 0:
+ with self.assertRaises(ValueError):
+ build_placer()
+ else:
+ prop_placer = build_placer()
+
+ prop_placer(physics, random_state=np.random.RandomState(0))
+
+ first_position, first_quaternion = spheres[0].get_pose(physics)
+ del first_quaternion # Unused.
+
+ # Settling always fails, so every attempt re-samples the position. The
+ # final y coordinate therefore equals the number of attempts made.
+ expected_first_y_pos = max_settle_physics_attempts
+ self.assertAlmostEqual(first_position[1], expected_first_y_pos, places=3)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/initializers/tcp_initializer.py b/dm_control/composer/initializers/tcp_initializer.py
new file mode 100644
index 00000000..0b8e5787
--- /dev/null
+++ b/dm_control/composer/initializers/tcp_initializer.py
@@ -0,0 +1,170 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An initializer that sets the pose of a hand's tool center point."""
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import variation
+from dm_control.entities.manipulators import base
+
+
+_REJECTION_SAMPLING_FAILED = (
+ 'Failed to find a valid initial configuration for the robot after '
+ '{max_rejection_samples} TCP poses sampled and up to {max_ik_attempts} '
+ 'initial joint configurations per pose.')
+
+
+class ToolCenterPointInitializer(composer.Initializer):
+ """An initializer that sets the position of a hand's tool center point.
+
+ This initializer calls the RobotArm's internal method to try and set the
+ hand's TCP to a randomized Cartesian position within the specified bound.
+ By default the initializer performs rejection sampling in order to avoid
+ poses that result in "relevant collisions", which are defined as:
+
+ * Collisions between links of the robot arm
+ * Collisions between the arm and the hand
+ * Collisions between either the arm or hand and an external body without a
+ free joint
+ """
+
+ def __init__(self,
+ hand,
+ arm,
+ position,
+ quaternion=base.DOWN_QUATERNION,
+ ignore_collisions=False,
+ max_ik_attempts=10,
+ max_rejection_samples=10):
+ """Initializes this ToolCenterPointInitializer.
+
+ Args:
+ hand: Either a `base.RobotHand` instance or None, in which case
+ `arm.wrist_site` is used as the TCP site in place of
+ `hand.tool_center_point`.
+ arm: A `base.RobotArm` instance.
+ position: A single fixed Cartesian position, or a `Variation`
+ object that generates Cartesian positions. If a fixed sequence of
+ positions for multiple props is desired, use
+ `variation.deterministic.Sequence`.
+ quaternion: (optional) A single fixed unit quaternion, or a
+ `composer.Variation` object that generates unit quaternions. If a fixed
+ sequence of quaternions for multiple props is desired, use
+ `variation.deterministic.Sequence`.
+ ignore_collisions: (optional) If True all collisions are ignored, i.e.
+ rejection sampling is disabled.
+ max_ik_attempts: (optional) Maximum number of attempts for the inverse
+ kinematics solver to find a solution satisfying `target_pos` and
+ `target_quat`. These are attempts per rejection sample. If more than
+ one attempt is performed, the joint configuration will be randomized
+ before the second trial. To avoid randomizing joint positions, set this
+ parameter to 1.
+ max_rejection_samples: (optional) Maximum number of TCP target poses to
+ sample while attempting to find a non-colliding configuration. For each
+ sampled pose, up to `max_ik_attempts` may be performed in order to find
+ an IK solution satisfying this pose.
+ """
+ super().__init__()
+ self._arm = arm
+ self._hand = hand
+ self._position = position
+ self._quaternion = quaternion
+ self._ignore_collisions = ignore_collisions
+ self._max_ik_attempts = max_ik_attempts
+ self._max_rejection_samples = max_rejection_samples
+
+ def _has_relevant_collisions(self, physics):
+ mjcf_root = self._arm.mjcf_model.root_model
+ all_geoms = mjcf_root.find_all('geom')
+ free_body_geoms = set()
+ for body in mjcf_root.worldbody.get_children('body'):
+ if mjcf.get_freejoint(body):
+ free_body_geoms.update(body.find_all('geom'))
+
+ arm_model = self._arm.mjcf_model
+ hand_model = None
+ if self._hand is not None:
+ hand_model = self._hand.mjcf_model
+
+ def is_robot(geom):
+ return geom.root is arm_model or geom.root is hand_model
+
+ def is_external_body_without_freejoint(geom):
+ return not (is_robot(geom) or geom in free_body_geoms)
+
+ for contact in physics.data.contact:
+ geom_1 = all_geoms[contact.geom1]
+ geom_2 = all_geoms[contact.geom2]
+ if contact.dist > 0:
+ # Ignore "contacts" with positive distance (i.e. not actually touching).
+ continue
+ if (
+ # Include arm-arm and arm-hand self-collisions (but not hand-hand).
+ (geom_1.root is arm_model and geom_2.root is arm_model) or
+ (geom_1.root is arm_model and geom_2.root is hand_model) or
+ (geom_1.root is hand_model and geom_2.root is arm_model) or
+ # Include collisions between the arm or hand and an external body
+ # provided that the external body does not have a freejoint.
+ (is_robot(geom_1) and is_external_body_without_freejoint(geom_2)) or
+ (is_external_body_without_freejoint(geom_1) and is_robot(geom_2))):
+ return True
+ return False
+
+ def __call__(self, physics, random_state):
+ """Sets initial tool center point pose via inverse kinematics.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ random_state: An `np.random.RandomState` instance.
+
+ Raises:
+ composer.EpisodeInitializationError: If a collision-free pose could not be
+ found within `max_rejection_samples` sampled TCP poses.
+ """
+ if self._hand is not None:
+ target_site = self._hand.tool_center_point
+ else:
+ target_site = self._arm.wrist_site
+
+ initial_qpos = physics.bind(self._arm.joints).qpos.copy()
+
+ for _ in range(self._max_rejection_samples):
+ target_pos = variation.evaluate(self._position,
+ random_state=random_state)
+ target_quat = variation.evaluate(self._quaternion,
+ random_state=random_state)
+ success = self._arm.set_site_to_xpos(
+ physics=physics, random_state=random_state, site=target_site,
+ target_pos=target_pos, target_quat=target_quat,
+ max_ik_attempts=self._max_ik_attempts)
+
+ if success:
+ physics.forward() # Recalculate contacts.
+ if (self._ignore_collisions
+ or not self._has_relevant_collisions(physics)):
+ return
+
+ # If IK failed to find a solution for this target pose, or if the solution
+ # resulted in contacts, then reset the arm joints to their original
+ # positions and try again with a new target.
+ physics.bind(self._arm.joints).qpos = initial_qpos
+
+ raise composer.EpisodeInitializationError(
+ _REJECTION_SAMPLING_FAILED.format(
+ max_rejection_samples=self._max_rejection_samples,
+ max_ik_attempts=self._max_ik_attempts,
+ )
+ )
diff --git a/dm_control/composer/initializers/tcp_initializer_test.py b/dm_control/composer/initializers/tcp_initializer_test.py
new file mode 100644
index 00000000..002ac8f6
--- /dev/null
+++ b/dm_control/composer/initializers/tcp_initializer_test.py
@@ -0,0 +1,225 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import functools
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.initializers import tcp_initializer
+from dm_control.entities import props
+from dm_control.entities.manipulators import kinova
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+
+class TcpInitializerTest(parameterized.TestCase):
+
+ def make_model(self, with_hand=True):
+ arm = kinova.JacoArm()
+ arena = composer.Arena()
+ arena.attach(arm)
+ if with_hand:
+ hand = kinova.JacoHand()
+ arm.attach(hand)
+ else:
+ hand = None
+ return arena, arm, hand
+
+ def assertTargetPoseAchieved(self, frame_binding, target_pos, target_quat):
+ np.testing.assert_array_almost_equal(target_pos, frame_binding.xpos)
+ target_xmat = np.empty(9, np.double)
+ mjlib.mju_quat2Mat(target_xmat, target_quat / np.linalg.norm(target_quat))
+ np.testing.assert_array_almost_equal(target_xmat, frame_binding.xmat)
+
+ def assertEntitiesInContact(self, physics, first, second):
+ first_geom_ids = physics.bind(
+ first.mjcf_model.find_all('geom')).element_id
+ second_geom_ids = physics.bind(
+ second.mjcf_model.find_all('geom')).element_id
+ contact = physics.data.contact
+ first_to_second = (np.isin(contact.geom1, first_geom_ids).ravel() &
+ np.isin(contact.geom2, second_geom_ids).ravel())
+ second_to_first = (np.isin(contact.geom1, second_geom_ids).ravel() &
+ np.isin(contact.geom2, first_geom_ids).ravel())
+ touching = contact.dist <= 0
+ valid_contact = touching & (first_to_second | second_to_first)
+ self.assertTrue(np.any(valid_contact), msg='Entities are not in contact.')
+
+ @parameterized.parameters([
+ dict(target_pos=np.array([0.1, 0.2, 0.3]),
+ target_quat=np.array([0., 1., 1., 0.]),
+ with_hand=True),
+ dict(target_pos=np.array([0., -0.1, 0.5]),
+ target_quat=np.array([1., 1., 0., 0.]),
+ with_hand=False),
+ ])
+ def test_initialize_to_fixed_pose(self, target_pos, target_quat, with_hand):
+ arena, arm, hand = self.make_model(with_hand=with_hand)
+ initializer = tcp_initializer.ToolCenterPointInitializer(
+ hand=hand, arm=arm, position=target_pos, quaternion=target_quat)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ initializer(physics=physics, random_state=np.random.RandomState(0))
+ site = hand.tool_center_point if with_hand else arm.wrist_site
+ self.assertTargetPoseAchieved(physics.bind(site), target_pos, target_quat)
+
+ def test_exception_if_hand_colliding_with_fixed_body(self):
+ arena, arm, hand = self.make_model()
+ target_pos = np.array([0.1, 0.2, 0.3])
+ target_quat = np.array([0., 1., 1., 0.])
+ max_rejection_samples = 10
+ max_ik_attempts = 5
+
+ # Place a fixed obstacle at the target location so that the TCP can't reach
+ # the target without colliding with it.
+ obstacle = props.Primitive(geom_type='sphere', size=[0.3])
+ attachment_frame = arena.attach(obstacle)
+ attachment_frame.pos = target_pos
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ make_initializer = functools.partial(
+ tcp_initializer.ToolCenterPointInitializer,
+ hand=hand,
+ arm=arm,
+ position=target_pos,
+ quaternion=target_quat,
+ max_ik_attempts=max_ik_attempts,
+ max_rejection_samples=max_rejection_samples)
+
+ initializer = make_initializer()
+ with self.assertRaisesWithLiteralMatch(
+ composer.EpisodeInitializationError,
+ tcp_initializer._REJECTION_SAMPLING_FAILED.format(
+ max_rejection_samples=max_rejection_samples,
+ max_ik_attempts=max_ik_attempts)):
+ initializer(physics=physics, random_state=np.random.RandomState(0))
+
+ # The initializer should succeed if we ignore collisions.
+ initializer_ignore_collisions = make_initializer(ignore_collisions=True)
+ initializer_ignore_collisions(physics=physics,
+ random_state=np.random.RandomState(0))
+ self.assertTargetPoseAchieved(
+ physics.bind(hand.tool_center_point), target_pos, target_quat)
+
+ # Confirm that the obstacle and the hand are in contact.
+ self.assertEntitiesInContact(physics, hand, obstacle)
+
+ @parameterized.named_parameters([
+ dict(testcase_name='between_arm_and_arm', with_hand=False),
+ dict(testcase_name='between_arm_and_hand', with_hand=True),
+ ])
+ def test_exception_if_self_collision(self, with_hand):
+ arena, arm, hand = self.make_model(with_hand=with_hand)
+ # This pose places the wrist or hand partially inside the base of the arm.
+ target_pos = np.array([0., 0.1, 0.1])
+ target_quat = np.array([-1., 1., 0., 0.])
+ max_rejection_samples = 10
+ max_ik_attempts = 5
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ make_initializer = functools.partial(
+ tcp_initializer.ToolCenterPointInitializer,
+ hand=hand,
+ arm=arm,
+ position=target_pos,
+ quaternion=target_quat,
+ max_ik_attempts=max_ik_attempts,
+ max_rejection_samples=max_rejection_samples)
+
+ initializer = make_initializer()
+ with self.assertRaisesWithLiteralMatch(
+ composer.EpisodeInitializationError,
+ tcp_initializer._REJECTION_SAMPLING_FAILED.format(
+ max_rejection_samples=max_rejection_samples,
+ max_ik_attempts=max_ik_attempts)):
+ initializer(physics=physics, random_state=np.random.RandomState(0))
+
+ # The initializer should succeed if we ignore collisions.
+ initializer_ignore_collisions = make_initializer(ignore_collisions=True)
+ initializer_ignore_collisions(physics=physics,
+ random_state=np.random.RandomState(0))
+ site = hand.tool_center_point if with_hand else arm.wrist_site
+ self.assertTargetPoseAchieved(
+ physics.bind(site), target_pos, target_quat)
+
+ # Confirm that there is self-collision.
+ self.assertEntitiesInContact(physics, arm, hand if with_hand else arm)
+
+ def test_ignore_robot_collision_with_free_body(self):
+ arena, arm, hand = self.make_model()
+ target_pos = np.array([0.1, 0.2, 0.3])
+ target_quat = np.array([0., 1., 1., 0.])
+
+ # The obstacle is still placed at the target location, but this time it has
+ # a freejoint and is held in place by a weld constraint.
+ obstacle = props.Primitive(geom_type='sphere', size=[0.3], pos=target_pos)
+ attachment_frame = arena.add_free_entity(obstacle)
+ attachment_frame.pos = target_pos
+ arena.mjcf_model.equality.add(
+ 'weld', body1=attachment_frame,
+ relpose=np.hstack([target_pos, [1., 0., 0., 0.]]))
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ initializer = tcp_initializer.ToolCenterPointInitializer(
+ hand=hand,
+ arm=arm,
+ position=target_pos,
+ quaternion=target_quat)
+
+ # Check that the initializer succeeds.
+ initializer(physics=physics, random_state=np.random.RandomState(0))
+ self.assertTargetPoseAchieved(
+ physics.bind(hand.tool_center_point), target_pos, target_quat)
+
+ # Confirm that the obstacle and the hand are in contact.
+ self.assertEntitiesInContact(physics, hand, obstacle)
+
+ def test_ignore_collision_not_involving_robot(self):
+ arena, arm, hand = self.make_model()
+ target_pos = np.array([0.1, 0.2, 0.3])
+ target_quat = np.array([0., 1., 1., 0.])
+
+ # Add two boxes that are always in contact with each other, but never with
+ # the arm or hand (since they are not within reach).
+ side_length = 0.1
+ x_offset = 10.
+ bottom_box = props.Primitive(
+ geom_type='box', size=[side_length]*3, pos=[x_offset, 0, 0])
+ top_box = props.Primitive(
+ geom_type='box', size=[side_length]*3, pos=[x_offset, 0, 2*side_length])
+ arena.attach(bottom_box)
+ arena.add_free_entity(top_box)
+
+ initializer = tcp_initializer.ToolCenterPointInitializer(
+ hand=hand,
+ arm=arm,
+ position=target_pos,
+ quaternion=target_quat)
+
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ # Confirm that there are actually contacts between the two boxes.
+ self.assertEntitiesInContact(physics, bottom_box, top_box)
+
+ # Check that the initializer still succeeds.
+ initializer(physics=physics, random_state=np.random.RandomState(0))
+ self.assertTargetPoseAchieved(
+ physics.bind(hand.tool_center_point), target_pos, target_quat)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/initializers/utils.py b/dm_control/composer/initializers/utils.py
new file mode 100644
index 00000000..68ac9646
--- /dev/null
+++ b/dm_control/composer/initializers/utils.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities that are helpful for implementing initializers."""
+
+import collections
+
+
+def _get_root_model(mjcf_elements):
+ root_model = mjcf_elements[0].root.root_model
+ for element in mjcf_elements:
+ if element.root.root_model != root_model:
+ raise ValueError('entities do not all belong to the same root model')
+ return root_model
+
+
+class JointStaticIsolator:
+ """Helper class that isolates a collection of MuJoCo joints from others.
+
+ An instance of this class is a context manager that caches the positions and
+ velocities of all non-isolated joints *upon construction*, and resets them to
+ their original state when the context exits.
+ """
+
+ def __init__(self, physics, joints):
+ """Initializes the joint isolator.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ joints: An iterable of `mjcf.Element` representing joints that may be
+ modified inside the context managed by this isolator.
+ """
+ if not isinstance(joints, collections.abc.Iterable):
+ joints = [joints]
+ root_model = _get_root_model(joints)
+ other_joints = [joint for joint in root_model.find_all('joint')
+ if joint not in joints]
+ if other_joints:
+ self._other_joints_mj = physics.bind(other_joints)
+ self._initial_qpos = self._other_joints_mj.qpos.copy()
+ self._initial_qvel = self._other_joints_mj.qvel.copy()
+ else:
+ self._other_joints_mj = None
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ del exc_type, exc_value, traceback # unused
+ if self._other_joints_mj:
+ self._other_joints_mj.qpos = self._initial_qpos
+ self._other_joints_mj.qvel = self._initial_qvel
diff --git a/dm_control/composer/observation/__init__.py b/dm_control/composer/observation/__init__.py
new file mode 100644
index 00000000..dd9cbfce
--- /dev/null
+++ b/dm_control/composer/observation/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Multi-rate observation and buffering framework for Composer environments."""
+
+from dm_control.composer.observation import observable
+from dm_control.composer.observation.obs_buffer import Buffer
+from dm_control.composer.observation.updater import DEFAULT_BUFFER_SIZE
+from dm_control.composer.observation.updater import DEFAULT_DELAY
+from dm_control.composer.observation.updater import DEFAULT_UPDATE_INTERVAL
+from dm_control.composer.observation.updater import Updater
diff --git a/dm_control/composer/observation/fake_physics.py b/dm_control/composer/observation/fake_physics.py
new file mode 100644
index 00000000..a5d05caf
--- /dev/null
+++ b/dm_control/composer/observation/fake_physics.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A fake Physics class for unit testing observation framework."""
+
+import contextlib
+
+from dm_control.composer.observation import observable
+from dm_control.rl import control
+import numpy as np
+
+
+class FakePhysics(control.Physics):
+ """A fake Physics class for unit testing observation framework."""
+
+ def __init__(self):
+ self._step_counter = 0
+ self._observables = {
+ 'twice': observable.Generic(FakePhysics.twice),
+ 'repeated': observable.Generic(FakePhysics.repeated, update_interval=5),
+ 'matrix': observable.Generic(FakePhysics.matrix, update_interval=3)
+ }
+
+ def step(self, sub_steps=1):
+ self._step_counter += 1
+
+ @property
+ def observables(self):
+ return self._observables
+
+ def twice(self):
+ return 2*self._step_counter
+
+ def repeated(self):
+ return [self._step_counter, self._step_counter]
+
+ def sqrt(self):
+ return np.sqrt(self._step_counter)
+
+ def sqrt_plus_one(self):
+ return np.sqrt(self._step_counter) + 1
+
+ def matrix(self):
+ return [[self._step_counter] * 3] * 2
+
+ def time(self):
+ return self._step_counter
+
+ def timestep(self):
+ return 1.0
+
+ def set_control(self, ctrl):
+ pass
+
+ def reset(self):
+ self._step_counter = 0
+
+ def after_reset(self):
+ pass
+
+ @contextlib.contextmanager
+ def suppress_physics_errors(self):
+ yield
diff --git a/dm_control/composer/observation/obs_buffer.py b/dm_control/composer/observation/obs_buffer.py
new file mode 100644
index 00000000..d656b9d7
--- /dev/null
+++ b/dm_control/composer/observation/obs_buffer.py
@@ -0,0 +1,251 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An object that manages the buffering and delaying of observation."""
+
+import collections
+import numpy as np
+
+
+class InFlightObservation:
+ """Represents a delayed observation that may not have arrived yet.
+
+ Attributes:
+ arrival: The time at which this observation will be delivered.
+ timestamp: The time at which this observation was made.
+ delay: The amount of delay between the time at which this observation was
+ made and the time at which it is delivered.
+ value: The value of this observation.
+ """
+
+ __slots__ = ('arrival', 'timestamp', 'delay', 'value')
+
+ def __init__(self, timestamp, delay, value):
+ self.arrival = timestamp + delay
+ self.timestamp = timestamp
+ self.delay = delay
+ self.value = value
+
+ def __lt__(self, other):
+ # This is implemented to facilitate sorting.
+ return self.arrival < other.arrival
+
+
+class Buffer:
+ """An object that manages the buffering and delaying of observation."""
+
+ def __init__(self, buffer_size, shape, dtype, pad_with_initial_value=False,
+ strip_singleton_buffer_dim=False):
+ """Initializes this observation buffer.
+
+ Args:
+ buffer_size: The size of the buffer returned by `read`. Note
+ that this does *not* affect the size of the internal buffer held by this
+ object, which always grows as large as is necessary in the presence of
+ large delays.
+ shape: The shape of a single observation held by this buffer, which can
+ either be a single integer or an iterable of integers. The shape of the
+ buffer returned by `read` will then be
+ `(buffer_size, shape[0], ..., shape[n])`, unless `buffer_size == 1`
+ and `strip_singleton_buffer_dim == True`.
+ dtype: The NumPy dtype of observation entries.
+ pad_with_initial_value: (optional) A boolean. If `True` then the buffer
+ returned by `read` is padded with the first observation value when there
+ are fewer observation entries than `buffer_size`. If `False` then the
+ buffer returned by `read` is padded with zeroes.
+ strip_singleton_buffer_dim: (optional) A boolean, if `True` and
+ `buffer_size == 1` then the leading dimension will not be added to the
+ shape of the array returned by `read`.
+ """
+ self._buffer_size = buffer_size
+ try:
+ shape = tuple(shape)
+ except TypeError:
+ if isinstance(shape, int):
+ shape = (shape,)
+ else:
+ raise
+
+ self._has_buffer_dim = not (strip_singleton_buffer_dim and buffer_size == 1)
+ if self._has_buffer_dim:
+ self._buffered_shape = (buffer_size,) + shape
+ else:
+ self._buffered_shape = shape
+ self._dtype = dtype
+
+ # The "arrived" deque contains entries that are due to be delivered now.
+ # This deque should never grow beyond buffer_size.
+ self._arrived_deque = collections.deque(maxlen=buffer_size)
+ if not pad_with_initial_value:
+ for _ in range(buffer_size):
+ self._arrived_deque.append(
+ InFlightObservation(-np.inf, 0, np.full(shape, 0, dtype)))
+
+ # The "pending" deque contains entries that are stored for future delivery.
+ # This deque can grow arbitrarily large in the presence of long delays.
+ self._pending_deque = collections.deque()
+
+ def _update_arrived_deque(self, timestamp):
+ while self._pending_deque and self._pending_deque[0].arrival <= timestamp:
+ self._arrived_deque.append(self._pending_deque.popleft())
+
+ @property
+ def shape(self):
+ return self._buffered_shape
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ def insert(self, timestamp, delay, value):
+ """Inserts a new observation to the buffer.
+
+ This function implicitly updates the internal "clock" of this buffer to
+ the timestamp of the new observation, and the internal buffer is trimmed
+ accordingly, i.e. at most `buffer_size` items whose delayed arrival time
+ precedes `timestamp` are kept.
+
+ Args:
+ timestamp: The time at which this observation was made.
+ delay: The amount of delay between the time at which this observation was
+ made and the time at which it is delivered.
+ value: The value of this observation.
+
+ Raises:
+ ValueError: if `delay` is negative.
+ """
+ # If using `pad_with_initial_value`, the `arrived_deque` will be empty on the
+ # first insert, so we pad it with the initial value now.
+ if not self._arrived_deque:
+ for _ in range(self._buffer_size):
+ self._arrived_deque.append(InFlightObservation(-np.inf, 0, value))
+
+ self._update_arrived_deque(timestamp)
+ new_obs = InFlightObservation(timestamp, delay, np.array(value))
+ arrival = new_obs.arrival
+ if delay == 0:
+ # No delay, so the new observation is due for immediate delivery.
+ # Add it to the arrived deque.
+ self._arrived_deque.append(new_obs)
+ elif delay > 0:
+ if not self._pending_deque or arrival > self._pending_deque[-1].arrival:
+ # New observation's arrival time is monotonic.
+ # Technically, we can handle this in the general code branch below,
+ # but since this is assumed to be the "typical" case, the special
+ # handling here saves us from repeatedly allocating and deallocating
+ # an empty temporary deque.
+ self._pending_deque.append(new_obs)
+ else:
+ # General, out-of-order observation.
+ arriving_after_new_obs = collections.deque()
+ while self._pending_deque and arrival < self._pending_deque[-1].arrival:
+ arriving_after_new_obs.appendleft(self._pending_deque.pop())
+ self._pending_deque.append(new_obs)
+ for existing_obs in arriving_after_new_obs:
+ self._pending_deque.append(existing_obs)
+ else:
+ raise ValueError('`delay` should not be negative: '
+ 'got {!r}'.format(delay))
+
+ def read(self, current_time):
+ """Reads the content of the buffer at the given timestamp."""
+ self._update_arrived_deque(current_time)
+ if self._has_buffer_dim:
+ out = np.empty(self._buffered_shape, dtype=self._dtype)
+ for i, obs in enumerate(self._arrived_deque):
+ out[i] = obs.value
+ else:
+ out = self._arrived_deque[0].value.copy()
+ return out
+
+ def drop_unobserved_upcoming_items(self, observation_schedule, read_interval):
+ """Plans an optimal observation schedule for an upcoming control period.
+
+ This function determines which of the proposed upcoming observations will
+ never in fact be delivered and removes them from the observation schedule.
+
+ We assume that observations will only be queried at times that are integer
+ multiples of `read_interval`. If more observations are generated during
+ the upcoming control step than the `buffer_size` of this `Buffer`
+ then some of those new observations will never be required. This function
+ takes into account the delayed arrival time and existing buffered items in
+ the planning process.
+
+ Args:
+ observation_schedule: A list of `(timestamp, delay)` tuples, where
+ `timestamp` is the time at which the observation value will be produced,
+ and `delay` is the amount of time the observation will be delayed by.
+ This list will be modified in place.
+ read_interval: The time interval between successive calls to `read`.
+ We assume that observations will only be queried at times that are
+ integer multiples of `read_interval`.
+ """
+ # Private deques to simulate what the deques will look like in the future,
+ # according to the proposed upcoming observation schedule.
+ future_arrived_deque = collections.deque()
+ future_pending_deque = collections.deque()
+
+ # Take existing buffered observations into account when planning the
+ # upcoming schedule.
+ def get_next_existing_timestamp():
+ for obs in reversed(self._pending_deque):
+ yield InFlightObservation(obs.timestamp, obs.delay, None)
+ while True:
+ yield InFlightObservation(-np.inf, 0, None)
+ existing_timestamp_iter = get_next_existing_timestamp()
+ existing_timestamp = next(existing_timestamp_iter)
+
+ # Build the simulated state of the pending deque at the end of the proposed
+ # schedule.
+ sorted_schedule = sorted([InFlightObservation(time[0], time[1], None)
+ for time in observation_schedule])
+ for new_timestamp in reversed(sorted_schedule):
+ # We don't need to worry about any existing items that are delivered
+ # before the first new item, since those are purged independently of our
+ # proposed new observations.
+ while existing_timestamp.arrival > new_timestamp.arrival:
+ future_pending_deque.appendleft(existing_timestamp)
+ existing_timestamp = next(existing_timestamp_iter)
+ future_pending_deque.appendleft(new_timestamp)
+
+ # Find the next timestep at which `read` is called.
+ first_proposed_timestamp = min(t for t, _ in observation_schedule)
+ next_read_time = read_interval * int(np.ceil(
+ first_proposed_timestamp // read_interval))
+
+ # Build the simulated state of the arrived deque at each subsequent
+ # control step.
+ while future_pending_deque:
+ # Keep track of observations that are delivered for the first time
+ # during this control timestep.
+ newly_arrived = collections.deque()
+ while (future_pending_deque and
+ future_pending_deque[0].arrival <= next_read_time):
+ # `fake_observation` is an `InFlightObservation` without `value`.
+ fake_observation = future_pending_deque.popleft()
+ future_arrived_deque.append(fake_observation)
+ newly_arrived.append(fake_observation)
+ while len(future_arrived_deque) > self._buffer_size:
+ stale = future_arrived_deque.popleft()
+ # Newly-arrived items that become immediately stale are never actually
+ # delivered.
+ if newly_arrived and stale == newly_arrived[0]:
+ newly_arrived.popleft()
+ # `stale` might either be one of the existing pending observations or
+ # from the proposed schedule.
+ if stale.timestamp >= first_proposed_timestamp:
+ observation_schedule.remove((stale.timestamp, stale.delay))
+
+ next_read_time += read_interval
diff --git a/dm_control/composer/observation/obs_buffer_test.py b/dm_control/composer/observation/obs_buffer_test.py
new file mode 100644
index 00000000..d7d083e0
--- /dev/null
+++ b/dm_control/composer/observation/obs_buffer_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for observation.obs_buffer."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.observation import obs_buffer
+import numpy as np
+
+
+def _generate_constant_schedule(update_timestep, delay, control_timestep,
+ n_observed_steps):
+ first = update_timestep
+ last = control_timestep * n_observed_steps + 1
+ return [(i, delay) for i in range(first, last, update_timestep)]
+
+
+class BufferTest(parameterized.TestCase):
+
+ def testOutOfOrderArrival(self):
+ buf = obs_buffer.Buffer(buffer_size=3, shape=(), dtype=float)
+ buf.insert(timestamp=0, delay=4, value=1)
+ buf.insert(timestamp=1, delay=2, value=2)
+ buf.insert(timestamp=2, delay=3, value=3)
+ np.testing.assert_array_equal(buf.read(current_time=2), [0., 0., 0.])
+ np.testing.assert_array_equal(buf.read(current_time=3), [0., 0., 2.])
+ np.testing.assert_array_equal(buf.read(current_time=4), [0., 2., 1.])
+ np.testing.assert_array_equal(buf.read(current_time=5), [2., 1., 3.])
+ np.testing.assert_array_equal(buf.read(current_time=6), [2., 1., 3.])
+
+ @parameterized.parameters(((3, 3),), ((),))
+ def testStripSingletonDimension(self, shape):
+ buf = obs_buffer.Buffer(
+ buffer_size=1,
+ shape=shape,
+ dtype=float,
+ strip_singleton_buffer_dim=True)
+ expected_value = np.full(shape, 42, dtype=float)
+ buf.insert(timestamp=0, delay=0, value=expected_value)
+ np.testing.assert_array_equal(buf.read(current_time=1), expected_value)
+
+ def testPlanToSingleUndelayedObservation(self):
+ buf = obs_buffer.Buffer(buffer_size=1, shape=(), dtype=float)
+ control_timestep = 20
+ observation_schedule = _generate_constant_schedule(
+ update_timestep=1,
+ delay=0,
+ control_timestep=control_timestep,
+ n_observed_steps=1)
+ buf.drop_unobserved_upcoming_items(
+ observation_schedule, read_interval=control_timestep)
+ self.assertEqual(observation_schedule, [(20, 0)])
+
+ def testPlanTwoStepsAhead(self):
+ buf = obs_buffer.Buffer(buffer_size=1, shape=(), dtype=float)
+ control_timestep = 5
+ observation_schedule = _generate_constant_schedule(
+ update_timestep=2,
+ delay=3,
+ control_timestep=control_timestep,
+ n_observed_steps=2)
+ buf.drop_unobserved_upcoming_items(
+ observation_schedule, read_interval=control_timestep)
+ self.assertEqual(observation_schedule, [(2, 3), (6, 3), (10, 3)])
+
+
+if __name__ == '__main__':
+ absltest.main()
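Reviewer note: the expectations in `testOutOfOrderArrival` follow from a simple rule: a read returns the most recently arrived values, ordered by arrival time (timestamp + delay), left-padded with zeros. The toy class below is a sketch of that rule only. `ToyBuffer` is a hypothetical name; it assumes `buffer_size >= 1`, scalar values, and it never purges old items, unlike a real buffer.

```python
class ToyBuffer:
    """Toy model of arrival-ordered reads with zero padding on the left."""

    def __init__(self, buffer_size):
        self._buffer_size = buffer_size
        # Each entry is (arrival_time, insertion_order, value).
        self._items = []

    def insert(self, timestamp, delay, value):
        arrival_time = timestamp + delay
        self._items.append((arrival_time, len(self._items), value))

    def read(self, current_time):
        # Keep only items that have arrived, sorted by arrival time
        # (insertion order breaks ties).
        arrived = sorted(item for item in self._items if item[0] <= current_time)
        latest = arrived[-self._buffer_size:]
        values = [float(value) for _, _, value in latest]
        padding = [0.0] * (self._buffer_size - len(values))
        return padding + values


buf = ToyBuffer(buffer_size=3)
buf.insert(timestamp=0, delay=4, value=1)  # arrives at t=4
buf.insert(timestamp=1, delay=2, value=2)  # arrives at t=3
buf.insert(timestamp=2, delay=3, value=3)  # arrives at t=5
```

The value inserted second arrives first, so it appears earliest in the read-out once all three have been delivered.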
diff --git a/dm_control/composer/observation/observable/__init__.py b/dm_control/composer/observation/observable/__init__.py
new file mode 100644
index 00000000..cd61957e
--- /dev/null
+++ b/dm_control/composer/observation/observable/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module for observables in the Composer library."""
+
+from dm_control.composer.observation.observable.base import Generic
+from dm_control.composer.observation.observable.base import MujocoCamera
+from dm_control.composer.observation.observable.base import MujocoFeature
+from dm_control.composer.observation.observable.base import Observable
+
+from dm_control.composer.observation.observable.mjcf import MJCFCamera
+from dm_control.composer.observation.observable.mjcf import MJCFFeature
diff --git a/dm_control/composer/observation/observable/base.py b/dm_control/composer/observation/observable/base.py
new file mode 100644
index 00000000..4d92e64a
--- /dev/null
+++ b/dm_control/composer/observation/observable/base.py
@@ -0,0 +1,309 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Classes representing observables."""
+
+import abc
+import functools
+
+from dm_env import specs
+import numpy as np
+
+
+def _make_aggregator(np_reducer_func, bounds_preserving):
+ result = functools.partial(np_reducer_func, axis=0)
+ setattr(result, 'bounds_preserving', bounds_preserving)
+ return result
+
+
+AGGREGATORS = {
+ 'min': _make_aggregator(np.min, True),
+ 'max': _make_aggregator(np.max, True),
+ 'mean': _make_aggregator(np.mean, True),
+ 'median': _make_aggregator(np.median, True),
+ 'sum': _make_aggregator(np.sum, False),
+}
+
+
+def _get_aggregator(name_or_callable):
+ """Returns aggregator from predefined set by name, else returns callable."""
+ if name_or_callable is None:
+ return None
+ elif not callable(name_or_callable):
+ try:
+ return AGGREGATORS[name_or_callable]
+ except KeyError:
+ raise KeyError('Unrecognized aggregator name: {!r}. Valid names: {}.'
+ .format(name_or_callable, sorted(AGGREGATORS)))
+ else:
+ return name_or_callable
+
+
+class Observable(metaclass=abc.ABCMeta):
+ """Abstract base class for an observable."""
+
+ def __init__(self, update_interval, buffer_size, delay,
+ aggregator, corruptor):
+ self._update_interval = update_interval
+ self._buffer_size = buffer_size
+ self._delay = delay
+ self._aggregator = _get_aggregator(aggregator)
+ self._corruptor = corruptor
+ self._enabled = False
+
+ @property
+ def update_interval(self):
+ return self._update_interval
+
+ @update_interval.setter
+ def update_interval(self, value):
+ self._update_interval = value
+
+ @property
+ def buffer_size(self):
+ return self._buffer_size
+
+ @buffer_size.setter
+ def buffer_size(self, value):
+ self._buffer_size = value
+
+ @property
+ def delay(self):
+ return self._delay
+
+ @delay.setter
+ def delay(self, value):
+ self._delay = value
+
+ @property
+ def aggregator(self):
+ return self._aggregator
+
+ @aggregator.setter
+ def aggregator(self, value):
+ self._aggregator = _get_aggregator(value)
+
+ @property
+ def corruptor(self):
+ return self._corruptor
+
+ @corruptor.setter
+ def corruptor(self, value):
+ self._corruptor = value
+
+ @property
+ def enabled(self):
+ return self._enabled
+
+ @enabled.setter
+ def enabled(self, value):
+ self._enabled = value
+
+ @property
+ def array_spec(self):
+ """The `ArraySpec` which describes observation arrays from this observable.
+
+ If this property is `None`, then the specification should be inferred by
+ actually retrieving an observation from this observable.
+ """
+ return None
+
+ @abc.abstractmethod
+ def _callable(self, physics):
+ pass
+
+ def observation_callable(self, physics, random_state=None):
+ """A callable which returns a (potentially corrupted) observation."""
+ raw_callable = self._callable(physics)
+ if self._corruptor:
+ def _corrupted():
+ return self._corruptor(raw_callable(), random_state=random_state)
+ return _corrupted
+ else:
+ return raw_callable
+
+ def __call__(self, physics, random_state=None):
+ """Convenience function to just call an observable."""
+ return self.observation_callable(physics, random_state)()
+
+ def configure(self, **kwargs):
+ """Sets multiple attributes of this observable.
+
+ Args:
+ **kwargs: The keyword argument names correspond to the attributes
+ being modified.
+ Raises:
+ AttributeError: If kwargs contained an attribute not in the observable.
+ """
+ for key, value in kwargs.items():
+ if not hasattr(self, key):
+ raise AttributeError('Cannot add attribute %s in configure.' % key)
+ setattr(self, key, value)
+
+
+class Generic(Observable):
+ """A generic observable defined via a callable."""
+
+ def __init__(self, raw_observation_callable, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None):
+ """Initializes this observable.
+
+ Args:
+ raw_observation_callable: A callable which accepts a single argument of
+ type `control.base.Physics` and returns the observation value.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an `observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ """
+ self._raw_callable = raw_observation_callable
+ super().__init__(update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ return lambda: self._raw_callable(physics)
+
+
+class MujocoFeature(Observable):
+ """An observable corresponding to a named MuJoCo feature."""
+
+ def __init__(self, kind, feature_name, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None):
+ """Initializes this observable.
+
+ Args:
+ kind: A string corresponding to a field name in MuJoCo's mjData struct.
+ feature_name: A string, or list of strings, or a callable returning
+ either, corresponding to the name(s) of an entity in the
+ MuJoCo XML model.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an `observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ """
+ self._kind = kind
+ self._feature_name = feature_name
+ super().__init__(update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ named_indexer_for_kind = getattr(physics.named.data, self._kind)
+ if callable(self._feature_name):
+ return lambda: named_indexer_for_kind[self._feature_name()]
+ else:
+ return lambda: named_indexer_for_kind[self._feature_name]
+
+
+class MujocoCamera(Observable):
+ """An observable corresponding to a MuJoCo camera."""
+
+ def __init__(self, camera_name, height=240, width=320, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None, depth=False):
+ """Initializes this observable.
+
+ Args:
+ camera_name: A string corresponding to the name of a camera in the
+ MuJoCo XML model.
+ height: (optional) An integer, the height of the rendered image.
+ width: (optional) An integer, the width of the rendered image.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an `observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ depth: (optional) A boolean. If `True`, renders a depth image (1-channel)
+ instead of RGB (3-channel).
+ """
+ self._camera_name = camera_name
+ self._height = height
+ self._width = width
+
+ self._n_channels = 1 if depth else 3
+ self._dtype = np.float32 if depth else np.uint8
+ self._depth = depth
+ super().__init__(update_interval, buffer_size, delay, aggregator, corruptor)
+
+ @property
+ def height(self):
+ return self._height
+
+ @height.setter
+ def height(self, value):
+ self._height = value
+
+ @property
+ def width(self):
+ return self._width
+
+ @width.setter
+ def width(self, value):
+ self._width = value
+
+ @property
+ def array_spec(self):
+ return specs.Array(
+ shape=(self._height, self._width, self._n_channels), dtype=self._dtype)
+
+ def _callable(self, physics):
+ return lambda: physics.render( # pylint: disable=g-long-lambda
+ self._height, self._width, self._camera_name, depth=self._depth)
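Reviewer note: the `aggregator` option reduces the buffered observations over their first (time) dimension, and the `bounds_preserving` flag records whether the result is guaranteed to stay within the bounds of the individual observations. The sketch below illustrates that distinction in plain Python instead of NumPy. `make_aggregator`, `TOY_AGGREGATORS` and `get_aggregator` are hypothetical stand-ins for the private helpers in `base.py`, and the buffer is modelled as a list of rows.

```python
import statistics


def make_aggregator(reducer, bounds_preserving):
    """Wraps `reducer` so that it reduces over the first (time) dimension."""

    def aggregate(buffer):
        # `buffer` is a list of equally sized rows, oldest first. Each column
        # holds the history of one observation element.
        return [reducer(column) for column in zip(*buffer)]

    aggregate.bounds_preserving = bounds_preserving
    return aggregate


TOY_AGGREGATORS = {
    'min': make_aggregator(min, True),
    'max': make_aggregator(max, True),
    'mean': make_aggregator(statistics.mean, True),
    'median': make_aggregator(statistics.median, True),
    # A sum of in-bounds values can leave the original bounds.
    'sum': make_aggregator(sum, False),
}


def get_aggregator(name_or_callable):
    """Looks up a named aggregator; passes callables and None through."""
    if name_or_callable is None or callable(name_or_callable):
        return name_or_callable
    try:
        return TOY_AGGREGATORS[name_or_callable]
    except KeyError:
        raise KeyError(
            'Unrecognized aggregator name: {!r}. Valid names: {}.'.format(
                name_or_callable, sorted(TOY_AGGREGATORS))) from None


# Three buffered readings of a 2-element observation bounded to [0, 1].
buffer = [[0.0, 1.0], [0.5, 1.0], [1.0, 0.0]]
summed = get_aggregator('sum')(buffer)
```

With this buffer every bounds-preserving aggregator stays inside [0, 1], while the summed result does not, which is why a consumer would need to widen or discard the bounds for `'sum'`.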
diff --git a/dm_control/composer/observation/observable/base_test.py b/dm_control/composer/observation/observable/base_test.py
new file mode 100644
index 00000000..6285dae2
--- /dev/null
+++ b/dm_control/composer/observation/observable/base_test.py
@@ -0,0 +1,147 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for observable classes."""
+
+from absl.testing import absltest
+from dm_control import mujoco
+from dm_control.composer.observation import fake_physics
+from dm_control.composer.observation.observable import base
+import numpy as np
+
+
+_MJCF = """
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+class _FakeBaseObservable(base.Observable):
+
+ def _callable(self, physics):
+ pass
+
+
+class ObservableTest(absltest.TestCase):
+
+ def testBaseProperties(self):
+ fake_observable = _FakeBaseObservable(update_interval=42,
+ buffer_size=5,
+ delay=10,
+ aggregator=None,
+ corruptor=None)
+ self.assertEqual(fake_observable.update_interval, 42)
+ self.assertEqual(fake_observable.buffer_size, 5)
+ self.assertEqual(fake_observable.delay, 10)
+
+ fake_observable.update_interval = 48
+ self.assertEqual(fake_observable.update_interval, 48)
+
+ fake_observable.buffer_size = 7
+ self.assertEqual(fake_observable.buffer_size, 7)
+
+ fake_observable.delay = 13
+ self.assertEqual(fake_observable.delay, 13)
+
+ enabled = not fake_observable.enabled
+ fake_observable.enabled = not fake_observable.enabled
+ self.assertEqual(fake_observable.enabled, enabled)
+
+ def testGeneric(self):
+ physics = fake_physics.FakePhysics()
+ repeated_observable = base.Generic(
+ fake_physics.FakePhysics.repeated, update_interval=42)
+ repeated_observation = repeated_observable.observation_callable(physics)()
+ self.assertEqual(repeated_observable.update_interval, 42)
+ np.testing.assert_array_equal(repeated_observation, [0, 0])
+
+ def testMujocoFeature(self):
+ physics = mujoco.Physics.from_xml_string(_MJCF)
+
+ hinge_observable = base.MujocoFeature(
+ kind='qpos', feature_name='my_hinge')
+ hinge_observation = hinge_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ hinge_observation, physics.named.data.qpos['my_hinge'])
+
+ sphere_observable = base.MujocoFeature(
+ kind='geom_xpos', feature_name='small_sphere', update_interval=5)
+ sphere_observation = sphere_observable.observation_callable(physics)()
+ self.assertEqual(sphere_observable.update_interval, 5)
+ np.testing.assert_array_equal(
+ sphere_observation, physics.named.data.geom_xpos['small_sphere'])
+
+ observable_from_callable = base.MujocoFeature(
+ kind='geom_xpos', feature_name=lambda: ['my_box', 'small_sphere'])
+ observation_from_callable = (
+ observable_from_callable.observation_callable(physics)())
+ np.testing.assert_array_equal(
+ observation_from_callable,
+ physics.named.data.geom_xpos[['my_box', 'small_sphere']])
+
+ def testMujocoCamera(self):
+ physics = mujoco.Physics.from_xml_string(_MJCF)
+
+ camera_observable = base.MujocoCamera(
+ camera_name='world', height=480, width=640, update_interval=7)
+ self.assertEqual(camera_observable.update_interval, 7)
+ camera_observation = camera_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(480, 640, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ camera_observable.height = 300
+ camera_observable.width = 400
+ camera_observation = camera_observable.observation_callable(physics)()
+ self.assertEqual(camera_observable.height, 300)
+ self.assertEqual(camera_observable.width, 400)
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(300, 400, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ def testCorruptor(self):
+ physics = fake_physics.FakePhysics()
+ def add_twelve(old_value, random_state):
+ del random_state # Unused.
+ return [x + 12 for x in old_value]
+ repeated_observable = base.Generic(
+ fake_physics.FakePhysics.repeated, corruptor=add_twelve)
+ corrupted = repeated_observable.observation_callable(
+ physics=physics, random_state=None)()
+ np.testing.assert_array_equal(corrupted, [12, 12])
+
+ def testInvalidAggregatorName(self):
+ name = 'invalid_name'
+ with self.assertRaisesRegex(KeyError, 'Unrecognized aggregator name'):
+ _ = _FakeBaseObservable(update_interval=3, buffer_size=2, delay=1,
+ aggregator=name, corruptor=None)
+
+if __name__ == '__main__':
+ absltest.main()
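Reviewer note: `testCorruptor` exercises the wrapping done by `observation_callable`, where the corruptor is applied to each raw observation and receives the `random_state` as a keyword argument. The sketch below shows the same composition in isolation. It is an illustration under assumptions: the free function `observation_callable`, `add_noise` and `read_sensor` are hypothetical, and `random.Random` stands in for whatever random state object the caller actually passes.

```python
import random


def observation_callable(raw_callable, corruptor=None, random_state=None):
    """Returns a callable producing a (potentially corrupted) observation."""
    if corruptor:
        def corrupted():
            # The corruptor sees the raw value plus the shared random state.
            return corruptor(raw_callable(), random_state=random_state)
        return corrupted
    return raw_callable


def add_noise(value, random_state):
    """Example corruptor: adds small Gaussian noise to each element."""
    return [x + random_state.gauss(0.0, 0.1) for x in value]


def read_sensor():
    """Stand-in for a raw observation source."""
    return [0.0, 0.0]


noisy_sensor = observation_callable(read_sensor, add_noise, random.Random(0))
noisy_reading = noisy_sensor()
```

Passing the random state in from outside, instead of creating it inside the corruptor, keeps noisy observations reproducible when the caller seeds it.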
diff --git a/dm_control/composer/observation/observable/mjcf.py b/dm_control/composer/observation/observable/mjcf.py
new file mode 100644
index 00000000..c591d7e7
--- /dev/null
+++ b/dm_control/composer/observation/observable/mjcf.py
@@ -0,0 +1,276 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Observables that are defined in terms of MJCF elements."""
+
+import collections
+
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base
+from dm_env import specs
+import numpy as np
+
+
+_BOTH_SEGMENTATION_AND_DEPTH_ENABLED = (
+ '`segmentation` and `depth` cannot both be `True`.')
+
+
+def _check_mjcf_element(obj):
+ if not isinstance(obj, mjcf.Element):
+ raise ValueError(
+ 'expected an `mjcf.Element`, got type {}: {}'.format(type(obj), obj))
+
+
+def _check_mjcf_element_iterable(obj_iterable):
+ if not isinstance(obj_iterable, collections.abc.Iterable):
+ obj_iterable = (obj_iterable,)
+ for obj in obj_iterable:
+ _check_mjcf_element(obj)
+
+
+class MJCFFeature(base.Observable):
+ """An observable corresponding to an element in an MJCF model."""
+
+ def __init__(self, kind, mjcf_element, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None, index=None):
+ """Initializes this observable.
+
+ Args:
+ kind: The name of an attribute of a bound `mjcf.Physics` instance. See the
+ docstring for `mjcf.Physics.bind()` for examples showing this syntax.
+ mjcf_element: An `mjcf.Element`, or iterable of `mjcf.Element`.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an `observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ index: (optional) An index that is to be applied to an array attribute
+ to pick out a slice or particular items. As a syntactic sugar,
+ `MJCFFeature` also implements `__getitem__` that returns a copy of the
+ same observable with an index applied.
+
+ Raises:
+ ValueError: if `mjcf_element` is not an `mjcf.Element`.
+ """
+ _check_mjcf_element_iterable(mjcf_element)
+ self._kind = kind
+ self._mjcf_element = mjcf_element
+ self._index = index
+ super().__init__(update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ binding = physics.bind(self._mjcf_element)
+ if self._index is not None:
+ return lambda: getattr(binding, self._kind)[self._index]
+ else:
+ return lambda: getattr(binding, self._kind)
+
+ def __getitem__(self, key):
+ if self._index is not None:
+ raise NotImplementedError(
+ 'slicing an already-sliced MJCFFeature observable is not supported')
+ return MJCFFeature(self._kind, self._mjcf_element, self._update_interval,
+ self._buffer_size, self._delay, self._aggregator,
+ self._corruptor, key)
+
+
+class MJCFCamera(base.Observable):
+ """An observable corresponding to a camera in an MJCF model."""
+
+ def __init__(self,
+ mjcf_element,
+ height=240,
+ width=320,
+ update_interval=1,
+ buffer_size=None,
+ delay=None,
+ aggregator=None,
+ corruptor=None,
+ depth=False,
+ segmentation=False,
+ scene_option=None,
+ render_flag_overrides=None):
+ """Initializes this observable.
+
+ Args:
+ mjcf_element: A `<camera>` `mjcf.Element`.
+ height: (optional) An integer, the height of the rendered image.
+ width: (optional) An integer, the width of the rendered image.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an `observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ depth: (optional) A boolean. If `True`, renders a depth image (1-channel)
+ instead of RGB (3-channel).
+ segmentation: (optional) A boolean. If `True`, renders a segmentation mask
+ (2-channel, int32) labeling the objects in the scene with their
+ (mjModel ID, mjtObj enum object type) pair. Background pixels are
+ set to (-1, -1).
+ scene_option: An optional `wrapper.MjvOption` instance that can be used to
+ render the scene with custom visualization options. If None then the
+ default options will be used.
+ render_flag_overrides: Optional mapping specifying rendering flags to
+ override. The keys can be either lowercase strings or `mjtRndFlag` enum
+ values, and the values are the overridden flag values, e.g.
+ `{'wireframe': True}` or `{mujoco.mjtRndFlag.mjRND_WIREFRAME: True}`.
+ See `mujoco.mjtRndFlag` for the set of valid flags. Must be None if
+ either `depth` or `segmentation` is True.
+
+ Raises:
+ ValueError: if `mjcf_element` is not a `<camera>` element.
+ ValueError: if segmentation and depth flags are both set to True.
+ """
+ _check_mjcf_element(mjcf_element)
+ if mjcf_element.tag != 'camera':
+ raise ValueError(
+ 'expected a <camera> element: got {}'.format(mjcf_element))
+ self._mjcf_element = mjcf_element
+ self._height = height
+ self._width = width
+
+ if segmentation and depth:
+ raise ValueError(_BOTH_SEGMENTATION_AND_DEPTH_ENABLED)
+ if segmentation:
+ self._dtype = np.int32
+ self._n_channels = 2
+ elif depth:
+ self._dtype = np.float32
+ self._n_channels = 1
+ else:
+ self._dtype = np.uint8
+ self._n_channels = 3
+ self._depth = depth
+ self._segmentation = segmentation
+ self._scene_option = scene_option
+ self._render_flag_overrides = render_flag_overrides
+ super().__init__(update_interval, buffer_size, delay, aggregator, corruptor)
+
+ @property
+ def height(self):
+ return self._height
+
+ @height.setter
+ def height(self, value):
+ self._height = value
+
+ @property
+ def width(self):
+ return self._width
+
+ @width.setter
+ def width(self, value):
+ self._width = value
+
+ @property
+ def depth(self):
+ return self._depth
+
+ @depth.setter
+ def depth(self, value):
+ self._depth = value
+
+ @property
+ def segmentation(self):
+ return self._segmentation
+
+ @segmentation.setter
+ def segmentation(self, value):
+ self._segmentation = value
+
+ @property
+ def scene_option(self):
+ return self._scene_option
+
+ @scene_option.setter
+ def scene_option(self, value):
+ self._scene_option = value
+
+ @property
+ def render_flag_overrides(self):
+ return self._render_flag_overrides
+
+ @render_flag_overrides.setter
+ def render_flag_overrides(self, value):
+ self._render_flag_overrides = value
+
+ @property
+ def array_spec(self):
+ if self._depth:
+ # These bounds are loose. The exact bounds are extent * (znear, zfar), but
+ # those parameters are unknown here because this method has no access to
+ # the compiled model.
+ minimum = 0.0
+ maximum = np.inf
+ elif self._segmentation:
+ # -1 denotes background pixels. See dm_control.mujoco.Camera.render for
+ # further details.
+ minimum = -1
+ maximum = np.iinfo(self._dtype).max
+ else:
+ minimum = np.iinfo(self._dtype).min
+ maximum = np.iinfo(self._dtype).max
+
+ return specs.BoundedArray(
+ minimum=minimum,
+ maximum=maximum,
+ shape=(self._height, self._width, self._n_channels),
+ dtype=self._dtype)
+
+ def _callable(self, physics):
+
+ def get_observation():
+ pixels = physics.render(
+ height=self._height,
+ width=self._width,
+ camera_id=self._mjcf_element.full_identifier,
+ depth=self._depth,
+ segmentation=self._segmentation,
+ scene_option=self._scene_option,
+ render_flag_overrides=self._render_flag_overrides)
+ return np.atleast_3d(pixels)
+
+ return get_observation
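Reviewer note: `MJCFCamera` picks the dtype, channel count and bounds of its output from the `depth` and `segmentation` flags. The sketch below summarises that mapping without NumPy or `dm_env`. `camera_spec` is a hypothetical helper that returns a plain dict in place of a `specs.BoundedArray`, and the integer limits are written out by hand.

```python
import math


def camera_spec(height, width, depth=False, segmentation=False):
    """Describes the rendered array for the given camera mode."""
    if segmentation and depth:
        raise ValueError('`segmentation` and `depth` cannot both be `True`.')
    if segmentation:
        # (model ID, object type) per pixel; -1 marks background pixels.
        dtype, n_channels = 'int32', 2
        minimum, maximum = -1, 2**31 - 1
    elif depth:
        # Loose bounds: the true range depends on the compiled model.
        dtype, n_channels = 'float32', 1
        minimum, maximum = 0.0, math.inf
    else:
        # Plain RGB image.
        dtype, n_channels = 'uint8', 3
        minimum, maximum = 0, 255
    return {
        'shape': (height, width, n_channels),
        'dtype': dtype,
        'minimum': minimum,
        'maximum': maximum,
    }


rgb_spec = camera_spec(240, 320)
```

Deriving these values from the flags at call time, as the sketch does, avoids the spec drifting out of date if a flag is changed after construction.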
diff --git a/dm_control/composer/observation/observable/mjcf_test.py b/dm_control/composer/observation/observable/mjcf_test.py
new file mode 100644
index 00000000..a29d27c4
--- /dev/null
+++ b/dm_control/composer/observation/observable/mjcf_test.py
@@ -0,0 +1,174 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for mjcf observables."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.observation.observable import mjcf as mjcf_observable
+from dm_env import specs
+import numpy as np
+
+_MJCF = """
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+class ObservableTest(parameterized.TestCase):
+
+  def testMJCFFeature(self):
+    mjcf_root = mjcf.from_xml_string(_MJCF)
+    physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+    my_hinge = mjcf_root.find('joint', 'my_hinge')
+    hinge_observable = mjcf_observable.MJCFFeature(
+        kind='qpos', mjcf_element=my_hinge)
+    hinge_observation = hinge_observable.observation_callable(physics)()
+    np.testing.assert_array_equal(
+        hinge_observation, physics.named.data.qpos[my_hinge.full_identifier])
+
+    small_sphere = mjcf_root.find('geom', 'small_sphere')
+    sphere_observable = mjcf_observable.MJCFFeature(
+        kind='xpos', mjcf_element=small_sphere, update_interval=5)
+    sphere_observation = sphere_observable.observation_callable(physics)()
+    self.assertEqual(sphere_observable.update_interval, 5)
+    np.testing.assert_array_equal(
+        sphere_observation, physics.named.data.geom_xpos[
+            small_sphere.full_identifier])
+
+    my_box = mjcf_root.find('geom', 'my_box')
+    list_observable = mjcf_observable.MJCFFeature(
+        kind='xpos', mjcf_element=[my_box, small_sphere])
+    list_observation = (
+        list_observable.observation_callable(physics)())
+    np.testing.assert_array_equal(
+        list_observation,
+        physics.named.data.geom_xpos[[my_box.full_identifier,
+                                      small_sphere.full_identifier]])
+
+    with self.assertRaisesRegex(ValueError, 'expected an `mjcf.Element`'):
+      mjcf_observable.MJCFFeature('qpos', 'my_hinge')
+    with self.assertRaisesRegex(ValueError, 'expected an `mjcf.Element`'):
+      mjcf_observable.MJCFFeature('geom_xpos', [my_box, 'small_sphere'])
+
+  def testMJCFFeatureIndex(self):
+    mjcf_root = mjcf.from_xml_string(_MJCF)
+    physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+    small_sphere = mjcf_root.find('geom', 'small_sphere')
+    sphere_xmat = np.array(
+        physics.named.data.geom_xmat[small_sphere.full_identifier])
+
+    observable_xrow = mjcf_observable.MJCFFeature(
+        'xmat', small_sphere, index=[1, 3, 5, 7])
+    np.testing.assert_array_equal(
+        observable_xrow.observation_callable(physics)(),
+        sphere_xmat[[1, 3, 5, 7]])
+
+    observable_yyzz = mjcf_observable.MJCFFeature('xmat', small_sphere)[2:6]
+    np.testing.assert_array_equal(
+        observable_yyzz.observation_callable(physics)(), sphere_xmat[2:6])
+
+  def testMJCFCamera(self):
+    mjcf_root = mjcf.from_xml_string(_MJCF)
+    physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+    camera = mjcf_root.find('camera', 'world')
+    camera_observable = mjcf_observable.MJCFCamera(
+        mjcf_element=camera, height=480, width=640, update_interval=7)
+    self.assertEqual(camera_observable.update_interval, 7)
+    camera_observation = camera_observable.observation_callable(physics)()
+    np.testing.assert_array_equal(
+        camera_observation, physics.render(480, 640, 'world'))
+    self.assertEqual(camera_observation.shape,
+                     camera_observable.array_spec.shape)
+    self.assertEqual(camera_observation.dtype,
+                     camera_observable.array_spec.dtype)
+
+    camera_observable.height = 300
+    camera_observable.width = 400
+    camera_observation = camera_observable.observation_callable(physics)()
+    self.assertEqual(camera_observable.height, 300)
+    self.assertEqual(camera_observable.width, 400)
+    np.testing.assert_array_equal(
+        camera_observation, physics.render(300, 400, 'world'))
+    self.assertEqual(camera_observation.shape,
+                     camera_observable.array_spec.shape)
+    self.assertEqual(camera_observation.dtype,
+                     camera_observable.array_spec.dtype)
+
+    with self.assertRaisesRegex(ValueError, 'expected an `mjcf.Element`'):
+      mjcf_observable.MJCFCamera('world')
+    with self.assertRaisesRegex(ValueError, 'expected an `mjcf.Element`'):
+      mjcf_observable.MJCFCamera([camera])
+    with self.assertRaisesRegex(ValueError, 'expected a '):
+      mjcf_observable.MJCFCamera(mjcf_root.find('body', 'body'))
+
+  @parameterized.parameters(
+      dict(camera_type='rgb', channels=3, dtype=np.uint8,
+           minimum=0, maximum=255),
+      dict(camera_type='depth', channels=1, dtype=np.float32,
+           minimum=0., maximum=np.inf),
+      dict(camera_type='segmentation', channels=2, dtype=np.int32,
+           minimum=-1, maximum=np.iinfo(np.int32).max),
+  )
+  def testMJCFCameraSpecs(self, camera_type, channels, dtype, minimum, maximum):
+    width = 640
+    height = 480
+    shape = (height, width, channels)
+    expected_spec = specs.BoundedArray(
+        shape=shape, dtype=dtype, minimum=minimum, maximum=maximum)
+    mjcf_root = mjcf.from_xml_string(_MJCF)
+    camera = mjcf_root.find('camera', 'world')
+    observable_kwargs = {} if camera_type == 'rgb' else {camera_type: True}
+    camera_observable = mjcf_observable.MJCFCamera(
+        mjcf_element=camera, height=height, width=width, update_interval=7,
+        **observable_kwargs)
+    self.assertEqual(camera_observable.array_spec, expected_spec)
+
+  def testMJCFSegCamera(self):
+    mjcf_root = mjcf.from_xml_string(_MJCF)
+    physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+    camera = mjcf_root.find('camera', 'world')
+    camera_observable = mjcf_observable.MJCFCamera(
+        mjcf_element=camera, height=480, width=640, update_interval=7,
+        segmentation=True)
+    self.assertEqual(camera_observable.update_interval, 7)
+    camera_observation = camera_observable.observation_callable(physics)()
+    np.testing.assert_array_equal(
+        camera_observation,
+        physics.render(480, 640, 'world', segmentation=True))
+    camera_observable.array_spec.validate(camera_observation)
+
+  def testErrorIfSegmentationAndDepthBothEnabled(self):
+    camera = mjcf.from_xml_string(_MJCF).find('camera', 'world')
+    with self.assertRaisesWithLiteralMatch(
+        ValueError, mjcf_observable._BOTH_SEGMENTATION_AND_DEPTH_ENABLED):
+      mjcf_observable.MJCFCamera(mjcf_element=camera, segmentation=True,
+                                 depth=True)
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/dm_control/composer/observation/updater.py b/dm_control/composer/observation/updater.py
new file mode 100644
index 00000000..9b145389
--- /dev/null
+++ b/dm_control/composer/observation/updater.py
@@ -0,0 +1,331 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An object that creates and updates buffers for enabled observables."""
+
+import collections
+import functools
+from absl import logging
+
+from dm_control.composer import variation
+from dm_control.composer.observation import obs_buffer
+from dm_env import specs
+import numpy as np
+
+DEFAULT_BUFFER_SIZE = 1
+DEFAULT_UPDATE_INTERVAL = 1
+DEFAULT_DELAY = 0
+
+
+class _EnabledObservable:
+  """Encapsulates an enabled observable, its buffer, and its update schedule."""
+
+  __slots__ = ('observable', 'observation_callable',
+               'update_interval', 'delay', 'buffer_size',
+               'buffer', 'update_schedule')
+
+  def __init__(self, observable, physics, random_state,
+               strip_singleton_buffer_dim, pad_with_initial_value):
+    self.observable = observable
+    self.observation_callable = (
+        observable.observation_callable(physics, random_state))
+
+    self._bind_attribute_from_observable('update_interval',
+                                         DEFAULT_UPDATE_INTERVAL,
+                                         random_state)
+    self._bind_attribute_from_observable('delay',
+                                         DEFAULT_DELAY,
+                                         random_state)
+    self._bind_attribute_from_observable('buffer_size',
+                                         DEFAULT_BUFFER_SIZE,
+                                         random_state)
+
+    obs_spec = self.observable.array_spec
+    if obs_spec is None:
+      # We take an observation to determine the shape and dtype of the array.
+      # This occurs outside of an episode and doesn't affect environment
+      # behavior. At this point the physics state is not guaranteed to be valid,
+      # so we might get a `PhysicsError` if the observation callable calls
+      # `physics.forward`. We suppress such errors since they do not matter as
+      # far as the shape and dtype of the observation are concerned.
+      with physics.suppress_physics_errors():
+        obs_array = self.observation_callable()
+      obs_array = np.asarray(obs_array)
+      obs_spec = specs.Array(shape=obs_array.shape, dtype=obs_array.dtype)
+    self.buffer = obs_buffer.Buffer(
+        buffer_size=self.buffer_size,
+        shape=obs_spec.shape, dtype=obs_spec.dtype,
+        pad_with_initial_value=pad_with_initial_value,
+        strip_singleton_buffer_dim=strip_singleton_buffer_dim)
+    self.update_schedule = collections.deque()
+
+  def _bind_attribute_from_observable(self, attr, default_value, random_state):
+    obs_attr = getattr(self.observable, attr)
+    if obs_attr:
+      if isinstance(obs_attr, variation.Variation):
+        setattr(self, attr,
+                functools.partial(obs_attr, random_state=random_state))
+      else:
+        setattr(self, attr, obs_attr)
+    else:
+      setattr(self, attr, default_value)
+
+
+def _call_if_callable(arg):
+  if callable(arg):
+    return arg()
+  else:
+    return arg
+
+
+def _validate_structure(structure):
+  """Validates the structure of the given observables collection.
+
+  The collection must either be a dict, or a (list or tuple) of dicts.
+
+  Args:
+    structure: A candidate collection of observables.
+
+  Returns:
+    A boolean that is `True` if `structure` is either a list or a tuple, or
+    `False` otherwise.
+
+  Raises:
+    ValueError: If `structure` is neither a dict nor a (list or tuple) of dicts.
+  """
+  is_nested = isinstance(structure, (list, tuple))
+  if is_nested:
+    is_valid = all(isinstance(obj, dict) for obj in structure)
+  else:
+    is_valid = isinstance(structure, dict)
+  if not is_valid:
+    raise ValueError(
+        '`observables` should be a dict, or a (list or tuple) of dicts'
+        ': got {}'.format(structure))
+  return is_nested
+
+
+class Updater:
+  """Creates and updates buffers for enabled observables."""
+
+  def __init__(self, observables, physics_steps_per_control_step=1,
+               strip_singleton_buffer_dim=False,
+               pad_with_initial_value=False):
+    self._physics_steps_per_control_step = physics_steps_per_control_step
+    self._strip_singleton_buffer_dim = strip_singleton_buffer_dim
+    self._pad_with_initial_value = pad_with_initial_value
+    self._step_counter = 0
+    self._observables = observables
+    self._is_nested = _validate_structure(observables)
+    self._enabled_structure = None
+    self._enabled_list = None
+
+  def reset(self, physics, random_state):
+    """Resets this updater's state."""
+
+    def make_buffers_dict(observables):
+      """Makes observable states in a dict."""
+      # Use `type(observables)` so that our output structure respects the
+      # original dict subclass (e.g. OrderedDict).
+      out_dict = type(observables)()
+      for key, value in observables.items():
+        if value.enabled:
+          out_dict[key] = _EnabledObservable(value, physics, random_state,
+                                             self._strip_singleton_buffer_dim,
+                                             self._pad_with_initial_value)
+      return out_dict
+
+    if self._is_nested:
+      self._enabled_structure = type(self._observables)(
+          make_buffers_dict(obs_dict) for obs_dict in self._observables)
+      self._enabled_list = []
+      for enabled_dict in self._enabled_structure:
+        self._enabled_list.extend(enabled_dict.values())
+    else:
+      self._enabled_structure = make_buffers_dict(self._observables)
+      self._enabled_list = self._enabled_structure.values()
+
+    self._step_counter = 0
+    for enabled in self._enabled_list:
+      first_delay = _call_if_callable(enabled.delay)
+      enabled.buffer.insert(
+          0, first_delay,
+          enabled.observation_callable())
+
+  def observation_spec(self):
+    """The observation specification for this environment.
+
+    Returns a dict mapping the names of enabled observations to their
+    corresponding `Array` or `BoundedArray` specs.
+
+    If an observable has a `BoundedArray` spec, but uses an aggregator that
+    does not preserve those bounds (such as `sum`), it will be mapped to an
+    (unbounded) `Array` spec. If using a bounds-preserving custom aggregator
+    `my_agg`, give it an attribute `my_agg.preserves_bounds = True` to indicate
+    to this method that it is bounds-preserving.
+
+    The returned specification is only valid as of the previous call
+    to `reset`. In particular, it is an error to call this function before
+    the first call to `reset`.
+
+    Returns:
+      A dict mapping observation name to `Array` or `BoundedArray` spec
+      containing the observation shape and dtype, and possibly bounds.
+
+    Raises:
+      RuntimeError: If this method is called before `reset` has been called.
+    """
+    if self._enabled_structure is None:
+      raise RuntimeError('`reset` must be called before `observation_spec`.')
+
+    def make_observation_spec_dict(enabled_dict):
+      """Makes a dict of observation specs from enabled observables."""
+      out_dict = type(enabled_dict)()
+      for name, enabled in enabled_dict.items():
+
+        if (enabled.observable.aggregator is None
+            and enabled.observable.array_spec is not None):
+          # If possible, keep the original array spec, just updating the name
+          # and modifying the dimension for buffering. Doing this allows for
+          # custom spec types to be exposed by the environment where possible.
+          out_dict[name] = enabled.observable.array_spec.replace(
+              name=name, shape=enabled.buffer.shape
+          )
+          continue
+
+        if isinstance(enabled.observable.array_spec, specs.BoundedArray):
+          bounds = (enabled.observable.array_spec.minimum,
+                    enabled.observable.array_spec.maximum)
+        else:
+          bounds = None
+
+        if enabled.observable.aggregator:
+          aggregator = enabled.observable.aggregator
+          aggregated = aggregator(np.zeros(enabled.buffer.shape,
+                                           dtype=enabled.buffer.dtype))
+          shape = aggregated.shape
+          dtype = aggregated.dtype
+
+          # Ditch bounds if the aggregator isn't known to be bounds-preserving.
+          if bounds:
+            if not hasattr(aggregator, 'preserves_bounds'):
+              logging.warning('Ignoring the bounds of this observable\'s spec, '
+                              'as its aggregator method has no boolean '
+                              '`preserves_bounds` attribute.')
+              bounds = None
+            elif not aggregator.preserves_bounds:
+              bounds = None
+        else:
+          shape = enabled.buffer.shape
+          dtype = enabled.buffer.dtype
+
+        if bounds:
+          spec = specs.BoundedArray(minimum=bounds[0],
+                                    maximum=bounds[1],
+                                    shape=shape,
+                                    dtype=dtype,
+                                    name=name)
+        else:
+          spec = specs.Array(shape=shape, dtype=dtype, name=name)
+
+        out_dict[name] = spec
+      return out_dict
+
+    if self._is_nested:
+      enabled_specs = type(self._enabled_structure)(
+          make_observation_spec_dict(enabled_dict)
+          for enabled_dict in self._enabled_structure)
+    else:
+      enabled_specs = make_observation_spec_dict(self._enabled_structure)
+
+    return enabled_specs
+
+  def prepare_for_next_control_step(self):
+    """Simulates the next control step and optimizes the update schedule."""
+    if self._enabled_structure is None:
+      raise RuntimeError('`reset` must be called before `before_step`.')
+    for enabled in self._enabled_list:
+
+      if (enabled.update_interval == DEFAULT_UPDATE_INTERVAL
+          and enabled.delay == DEFAULT_DELAY
+          and enabled.buffer_size < self._physics_steps_per_control_step):
+        for i in reversed(range(enabled.buffer_size)):
+          next_step = (
+              self._step_counter + self._physics_steps_per_control_step - i)
+          next_delay = DEFAULT_DELAY
+          enabled.update_schedule.append((next_step, next_delay))
+      else:
+        if enabled.update_schedule:
+          last_scheduled_step = enabled.update_schedule[-1][0]
+        else:
+          last_scheduled_step = self._step_counter
+        max_step = self._step_counter + 2 * self._physics_steps_per_control_step
+        while last_scheduled_step < max_step:
+          next_update_interval = _call_if_callable(enabled.update_interval)
+          next_step = last_scheduled_step + next_update_interval
+          next_delay = _call_if_callable(enabled.delay)
+          enabled.update_schedule.append((next_step, next_delay))
+          last_scheduled_step = next_step
+        # Optimize the schedule by planning ahead and dropping unseen entries.
+        enabled.buffer.drop_unobserved_upcoming_items(
+            enabled.update_schedule, self._physics_steps_per_control_step)
+
+  def update(self):
+    if self._enabled_structure is None:
+      raise RuntimeError('`reset` must be called before `after_substep`.')
+    self._step_counter += 1
+    for enabled in self._enabled_list:
+      if (enabled.update_schedule and
+          enabled.update_schedule[0][0] == self._step_counter):
+        timestamp, delay = enabled.update_schedule.popleft()
+        enabled.buffer.insert(
+            timestamp, delay,
+            enabled.observation_callable())
+
+  def get_observation(self):
+    """Gets the current observation.
+
+    The returned observation is only valid as of the previous call
+    to `reset`. In particular, it is an error to call this function before
+    the first call to `reset`.
+
+    Returns:
+      A dict, or list of dicts, or tuple of dicts, of observation values.
+      The returned structure corresponds to the structure of the `observables`
+      that was given at initialization time.
+
+    Raises:
+      RuntimeError: If this method is called before `reset` has been called.
+    """
+    if self._enabled_structure is None:
+      raise RuntimeError('`reset` must be called before `observation`.')
+
+    def aggregate_dict(enabled_dict):
+      out_dict = type(enabled_dict)()
+      for name, enabled in enabled_dict.items():
+        if enabled.observable.aggregator:
+          aggregated = enabled.observable.aggregator(
+              enabled.buffer.read(self._step_counter))
+        else:
+          aggregated = enabled.buffer.read(self._step_counter)
+        out_dict[name] = aggregated
+      return out_dict
+
+    if self._is_nested:
+      return type(self._enabled_structure)(
+          aggregate_dict(enabled_dict)
+          for enabled_dict in self._enabled_structure)
+    else:
+      return aggregate_dict(self._enabled_structure)
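Reviewer note on `updater.py`: the first branch of `prepare_for_next_control_step` appears to be a fast path for observables with default update interval and delay, where only the last `buffer_size` physics steps of the control step need to be recorded. The sketch below restates that arithmetic in isolation. `fast_path_schedule` is a hypothetical helper written for illustration and is not part of dm_control.

```python
def fast_path_schedule(step_counter, steps_per_control_step, buffer_size):
    """Returns the (step, delay) pairs scheduled on the default-settings path.

    Hypothetical helper, not part of dm_control: it mirrors the loop in the
    first branch of `prepare_for_next_control_step` for a single observable.
    """
    default_delay = 0  # Mirrors DEFAULT_DELAY in updater.py.
    return [(step_counter + steps_per_control_step - i, default_delay)
            for i in reversed(range(buffer_size))]


# With 5 physics steps per control step and a buffer of size 2, only the
# final two physics steps of the control step are scheduled.
schedule = fast_path_schedule(step_counter=0, steps_per_control_step=5,
                              buffer_size=2)
print(schedule)  # [(4, 0), (5, 0)]
```

Earlier physics steps would be overwritten in the buffer before the agent could read them, which is presumably why they are never scheduled on this path.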
diff --git a/dm_control/composer/observation/updater_test.py b/dm_control/composer/observation/updater_test.py
new file mode 100644
index 00000000..a524b804
--- /dev/null
+++ b/dm_control/composer/observation/updater_test.py
@@ -0,0 +1,304 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for observation.updater."""
+
+import collections
+import itertools
+import math
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.observation import fake_physics
+from dm_control.composer.observation import observable
+from dm_control.composer.observation import updater
+from dm_env import specs
+import numpy as np
+
+
+class DeterministicSequence:
+
+  def __init__(self, sequence):
+    self._iter = itertools.cycle(sequence)
+
+  def __call__(self, random_state=None):
+    del random_state  # unused
+    return next(self._iter)
+
+
+class BoundedGeneric(observable.Generic):
+
+  def __init__(self, raw_observation_callable, minimum, maximum, **kwargs):
+    super().__init__(
+        raw_observation_callable=raw_observation_callable, **kwargs)
+    self._bounds = (minimum, maximum)
+
+  @property
+  def array_spec(self):
+    datum = np.array(self(None, None))
+    return specs.BoundedArray(shape=datum.shape,
+                              dtype=datum.dtype,
+                              minimum=self._bounds[0],
+                              maximum=self._bounds[1])
+
+
+class MyArraySpec(specs.Array):
+  pass
+
+
+class GenericObservableWithMyArraySpec(observable.Generic):
+  @property
+  def array_spec(self):
+    datum = np.array(self(None, None))
+    return MyArraySpec(shape=datum.shape, dtype=datum.dtype)
+
+
+class UpdaterTest(parameterized.TestCase):
+
+  @parameterized.parameters(list, tuple)
+  def testNestedSpecsAndValues(self, list_or_tuple):
+    observables = list_or_tuple((
+        {'one': observable.Generic(lambda _: 1.),
+         'two': observable.Generic(lambda _: [2, 2]),
+        }, collections.OrderedDict([
+            ('three', observable.Generic(lambda _: np.full((2, 2), 3))),
+            ('four', observable.Generic(lambda _: [4.])),
+            ('five', observable.Generic(lambda _: 5)),
+            ('six', BoundedGeneric(lambda _: [2, 2], 1, 4)),
+            ('seven', BoundedGeneric(lambda _: 2, 1, 4, aggregator='sum')),
+        ])
+    ))
+
+    observables[0]['two'].enabled = True
+    observables[1]['three'].enabled = True
+    observables[1]['five'].enabled = True
+    observables[1]['six'].enabled = True
+    observables[1]['seven'].enabled = True
+
+    observation_updater = updater.Updater(observables)
+    observation_updater.reset(physics=fake_physics.FakePhysics(),
+                              random_state=None)
+
+    def make_spec(obs):
+      array = np.array(obs.observation_callable(None, None)())
+      shape = array.shape if obs.aggregator else (1,) + array.shape
+
+      if (isinstance(obs, BoundedGeneric) and
+          obs.aggregator is not observable.base.AGGREGATORS['sum']):
+        return specs.BoundedArray(shape=shape,
+                                  dtype=array.dtype,
+                                  minimum=obs.array_spec.minimum,
+                                  maximum=obs.array_spec.maximum)
+      else:
+        return specs.Array(shape=shape, dtype=array.dtype)
+
+    expected_specs = list_or_tuple((
+        {'two': make_spec(observables[0]['two'])},
+        collections.OrderedDict([
+            ('three', make_spec(observables[1]['three'])),
+            ('five', make_spec(observables[1]['five'])),
+            ('six', make_spec(observables[1]['six'])),
+            ('seven', make_spec(observables[1]['seven'])),
+        ])
+    ))
+
+    actual_specs = observation_updater.observation_spec()
+    self.assertIs(type(actual_specs), type(expected_specs))
+    for actual_dict, expected_dict in zip(actual_specs, expected_specs):
+      self.assertIs(type(actual_dict), type(expected_dict))
+      self.assertEqual(actual_dict, expected_dict)
+
+    def make_value(obs):
+      value = obs(physics=None, random_state=None)
+      if obs.aggregator:
+        return value
+      else:
+        value = np.array(value)
+        value = value[np.newaxis, ...]
+        return value
+
+    expected_values = list_or_tuple((
+        {'two': make_value(observables[0]['two'])},
+        collections.OrderedDict([
+            ('three', make_value(observables[1]['three'])),
+            ('five', make_value(observables[1]['five'])),
+            ('six', make_value(observables[1]['six'])),
+            ('seven', make_value(observables[1]['seven'])),
+        ])
+    ))
+
+    actual_values = observation_updater.get_observation()
+    self.assertIs(type(actual_values), type(expected_values))
+    for actual_dict, expected_dict in zip(actual_values, expected_values):
+      self.assertIs(type(actual_dict), type(expected_dict))
+      self.assertLen(actual_dict, len(expected_dict))
+      for actual, expected in zip(actual_dict.items(), expected_dict.items()):
+        actual_name, actual_value = actual
+        expected_name, expected_value = expected
+        self.assertEqual(actual_name, expected_name)
+        np.testing.assert_array_equal(actual_value, expected_value)
+
+  def assertCorrectSpec(
+      self, spec, expected_shape, expected_dtype, expected_name):
+    self.assertEqual(spec.shape, expected_shape)
+    self.assertEqual(spec.dtype, expected_dtype)
+    self.assertEqual(spec.name, expected_name)
+
+  def testObservationSpecInference(self):
+    physics = fake_physics.FakePhysics()
+    physics.observables['repeated'].buffer_size = 5
+    physics.observables['matrix'].buffer_size = 4
+    physics.observables['sqrt'] = observable.Generic(
+        fake_physics.FakePhysics.sqrt, buffer_size=3)
+
+    for obs in physics.observables.values():
+      obs.enabled = True
+
+    observation_updater = updater.Updater(physics.observables)
+    observation_updater.reset(physics=physics, random_state=None)
+
+    spec = observation_updater.observation_spec()
+    self.assertCorrectSpec(spec['repeated'], (5, 2), int, 'repeated')
+    self.assertCorrectSpec(spec['matrix'], (4, 2, 3), int, 'matrix')
+    self.assertCorrectSpec(spec['sqrt'], (3,), float, 'sqrt')
+
+  def testCustomSpecTypePassedThrough(self):
+    physics = fake_physics.FakePhysics()
+    physics.observables['two_twos'] = GenericObservableWithMyArraySpec(
+        lambda _: [2.0, 2.0], buffer_size=3
+    )
+
+    physics.observables['two_twos'].enabled = True
+
+    observation_updater = updater.Updater(physics.observables)
+    observation_updater.reset(physics=physics, random_state=None)
+
+    spec = observation_updater.observation_spec()
+    self.assertIsInstance(spec['two_twos'], MyArraySpec)
+    self.assertEqual(spec['two_twos'].shape, (3, 2))
+    self.assertEqual(spec['two_twos'].dtype, float)
+    self.assertEqual(spec['two_twos'].name, 'two_twos')
+
+  @parameterized.parameters(True, False)
+  def testObservation(self, pad_with_initial_value):
+    physics = fake_physics.FakePhysics()
+    physics.observables['repeated'].buffer_size = 5
+    physics.observables['matrix'].delay = 1
+    physics.observables['sqrt_plus_one'] = observable.Generic(
+        fake_physics.FakePhysics.sqrt_plus_one, update_interval=7,
+        buffer_size=3, delay=2)
+    for obs in physics.observables.values():
+      obs.enabled = True
+    with physics.reset_context():
+      pass
+
+    physics_steps_per_control_step = 5
+    observation_updater = updater.Updater(
+        physics.observables, physics_steps_per_control_step,
+        pad_with_initial_value=pad_with_initial_value)
+    observation_updater.reset(physics=physics, random_state=None)
+
+    for control_step in range(0, 200):
+      observation_updater.prepare_for_next_control_step()
+      for _ in range(physics_steps_per_control_step):
+        physics.step()
+        observation_updater.update()
+
+      step_counter = (control_step + 1) * physics_steps_per_control_step
+
+      observation = observation_updater.get_observation()
+      def assert_correct_buffer(obs_name, expected_callable,
+                                observation=observation,
+                                step_counter=step_counter):
+        update_interval = (physics.observables[obs_name].update_interval
+                           or updater.DEFAULT_UPDATE_INTERVAL)
+        buffer_size = (physics.observables[obs_name].buffer_size
+                       or updater.DEFAULT_BUFFER_SIZE)
+        delay = (physics.observables[obs_name].delay
+                 or updater.DEFAULT_DELAY)
+
+        # The final item in the buffer is the current time, less the delay,
+        # rounded _down_ to the nearest multiple of the update interval.
+        end = update_interval * int(
+            math.floor((step_counter - delay) / update_interval))
+
+        # Figure out the first item in the buffer by working backwards from
+        # the final item in multiples of the update interval.
+        start = end - (buffer_size - 1) * update_interval
+
+        # Clamp both the start and end step number below by zero.
+        buffer_range = range(max(0, start), max(0, end + 1), update_interval)
+
+        # Arrays with expected shapes, filled with expected default values.
+        expected_value_spec = observation_updater.observation_spec()[obs_name]
+        if pad_with_initial_value:
+          expected_values = np.full(shape=expected_value_spec.shape,
+                                    fill_value=expected_callable(0),
+                                    dtype=expected_value_spec.dtype)
+        else:
+          expected_values = np.zeros(shape=expected_value_spec.shape,
+                                     dtype=expected_value_spec.dtype)
+
+        # The arrays are filled from right to left, such that the most recent
+        # entry is the rightmost one, and any padding is on the left.
+        for index, timestamp in enumerate(reversed(buffer_range)):
+          expected_values[-(index+1)] = expected_callable(timestamp)
+
+        np.testing.assert_array_equal(observation[obs_name], expected_values)
+
+      assert_correct_buffer('twice', lambda x: 2*x)
+      assert_correct_buffer('matrix', lambda x: [[x]*3]*2)
+      assert_correct_buffer('repeated', lambda x: [x, x])
+      assert_correct_buffer('sqrt_plus_one', lambda x: np.sqrt(x) + 1)
+
+  def testVariableRatesAndDelays(self):
+    physics = fake_physics.FakePhysics()
+    physics.observables['time'] = observable.Generic(
+        lambda physics: physics.time(),
+        buffer_size=3,
+        # observations produced on step numbers 20*N + [0, 3, 5, 8, 11, 15, 16]
+        update_interval=DeterministicSequence([3, 2, 3, 3, 4, 1, 4]),
+        # observations arrive on step numbers 20*N + [3, 8, 7, 12, 11, 17, 20]
+        delay=DeterministicSequence([3, 5, 2, 5, 1, 2, 4]))
+    physics.observables['time'].enabled = True
+
+    physics_steps_per_control_step = 10
+    observation_updater = updater.Updater(
+        physics.observables, physics_steps_per_control_step)
+    observation_updater.reset(physics=physics, random_state=None)
+
+    # Run through a few cycles of the variation sequences to make sure that
+    # cross-control-boundary behaviour is correct.
+    for i in range(5):
+      observation_updater.prepare_for_next_control_step()
+      for _ in range(physics_steps_per_control_step):
+        physics.step()
+        observation_updater.update()
+      np.testing.assert_array_equal(
+          observation_updater.get_observation()['time'],
+          20*i + np.array([0, 5, 3]))
+
+      observation_updater.prepare_for_next_control_step()
+      for _ in range(physics_steps_per_control_step):
+        physics.step()
+        observation_updater.update()
+      # Note that #11 is dropped since it arrives after #8,
+      # whose large delay caused it to cross the control step boundary at #10.
+      np.testing.assert_array_equal(
+          observation_updater.get_observation()['time'],
+          20*i + np.array([8, 15, 16]))
+
+if __name__ == '__main__':
+  absltest.main()
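Reviewer note on `updater_test.py`: `testObservation` derives the expected buffer contents from the step counter, update interval, buffer size and delay. The sketch below isolates that calculation so the rounding and clamping are easier to check by hand. `expected_timestamps` is a hypothetical name used only for this illustration, and it assumes fixed (non-callable) intervals and delays, as in that test.

```python
import math


def expected_timestamps(step_counter, update_interval=1, buffer_size=1,
                        delay=0):
    """Returns the physics steps whose observations should be in the buffer.

    Hypothetical helper restating the arithmetic in `assert_correct_buffer`.
    """
    # Most recent observation: current step, less the delay, rounded down to
    # a multiple of the update interval.
    last = update_interval * math.floor((step_counter - delay) / update_interval)
    # Walk backwards in multiples of the update interval to find the oldest.
    first = last - (buffer_size - 1) * update_interval
    # Clamp at zero: nothing is observed before the episode starts.
    return list(range(max(0, first), max(0, last + 1), update_interval))


# Settings of the 'sqrt_plus_one' observable in the test.
print(expected_timestamps(35, update_interval=7, buffer_size=3, delay=2))
# [14, 21, 28]
print(expected_timestamps(5, update_interval=7, buffer_size=3, delay=2))
# [0]
```

In the second call only one real observation exists, so the remaining buffer slots would hold padding (zeros, or the initial value when `pad_with_initial_value` is set).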
diff --git a/dm_control/composer/robot.py b/dm_control/composer/robot.py
new file mode 100644
index 00000000..a4c4be50
--- /dev/null
+++ b/dm_control/composer/robot.py
@@ -0,0 +1,33 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract robot class."""
+
+import abc
+
+from dm_control.composer import entity
+import numpy as np
+
+DOWN_QUATERNION = np.array([0., 0.70710678118, 0.70710678118, 0.])
+
+
+class Robot(entity.Entity, metaclass=abc.ABCMeta):
+  """The abstract base class for robots."""
+
+  @property
+  @abc.abstractmethod
+  def actuators(self):
+    """Returns the actuator elements of the robot."""
+    raise NotImplementedError
diff --git a/dm_control/composer/task.py b/dm_control/composer/task.py
new file mode 100644
index 00000000..f24d13fb
--- /dev/null
+++ b/dm_control/composer/task.py
@@ -0,0 +1,322 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Abstract base class for a Composer task."""
+
+import abc
+import collections
+import copy
+
+from dm_control import mujoco
+from dm_env import specs
+
+
+def _check_timesteps_divisible(control_timestep, physics_timestep):
+  num_steps = control_timestep / physics_timestep
+  rounded_num_steps = int(round(num_steps))
+  if abs(num_steps - rounded_num_steps) > 1e-6:
+    raise ValueError(
+        'Control timestep should be an integer multiple of physics timestep'
+        ': got {!r} and {!r}'.format(control_timestep, physics_timestep))
+  return rounded_num_steps
+
+
+class Task(metaclass=abc.ABCMeta):
+  """Abstract base class for a Composer task."""
+
+  @abc.abstractproperty
+  def root_entity(self):
+    """A `base.Entity` instance for this task."""
+    raise NotImplementedError
+
+  def iter_entities(self):
+    return self.root_entity.iter_entities()
+
+ @property
+ def observables(self):
+ """An OrderedDict of `control.Observable` instances for this task.
+
+ Task subclasses should generally NOT override this property.
+
+ This property is automatically computed by combining the observables dict
+ provided by each `Entity` present in this task, and any additional
+ observables returned via the `task_observables` property.
+
+ To provide an observable to an agent, the task code should either set the
+ `enabled` property of an `Entity`-bound observable to `True`, or override
+ the `task_observables` property to provide additional observables not bound
+ to an `Entity`.
+
+ Returns:
+ A `collections.OrderedDict` mapping strings to instances of
+ `control.Observable`.
+ """
+ # Make a shallow copy of the OrderedDict, not the Observables themselves.
+ observables = copy.copy(self.task_observables)
+ for entity in self.root_entity.iter_entities():
+ observables.update(entity.observables.as_dict())
+ return observables
+
+ @property
+ def task_observables(self):
+ """An OrderedDict of task-specific `control.Observable` instances.
+
+ A task should override this property if it wants to provide additional
+ observables to the agent that are not already provided by any `Entity` that
+ forms part of the task's model. For example, this may be used to provide
+ observations that are derived from relative poses between two entities.
+
+ Returns:
+ A `collections.OrderedDict` mapping strings to instances of
+ `control.Observable`.
+ """
+ return collections.OrderedDict()
+
+ def after_compile(self, physics, random_state):
+ """A callback which is executed after the MuJoCo Physics is recompiled.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def _check_root_entity(self, callee_name):
+ try:
+ _ = self.root_entity
+ except Exception as effect:
+ cause = RuntimeError(
+ f'call to `{callee_name}` made before `root_entity` is available')
+ raise effect from cause
+
+ @property
+ def control_timestep(self):
+ """Returns the agent's control timestep for this task (in seconds)."""
+ self._check_root_entity('control_timestep')
+ if hasattr(self, '_control_timestep'):
+ return self._control_timestep
+ else:
+ return self.physics_timestep
+
+ @control_timestep.setter
+ def control_timestep(self, new_value):
+ """Changes the agent's control timestep for this task.
+
+ Args:
+ new_value: the new control timestep (in seconds).
+
+ Raises:
+ ValueError: if `new_value` is set and is not divisible by
+ `physics_timestep`.
+ """
+ self._check_root_entity('control_timestep')
+ _check_timesteps_divisible(new_value, self.physics_timestep)
+ self._control_timestep = new_value
+
+ @property
+ def physics_timestep(self):
+ """Returns the physics timestep for this task (in seconds)."""
+ self._check_root_entity('physics_timestep')
+ if self.root_entity.mjcf_model.option.timestep is None:
+ return 0.002 # MuJoCo's default.
+ else:
+ return self.root_entity.mjcf_model.option.timestep
+
+ @physics_timestep.setter
+ def physics_timestep(self, new_value):
+ """Changes the physics simulation timestep for this task.
+
+ Args:
+ new_value: the new simulation timestep (in seconds).
+
+ Raises:
+ ValueError: if `control_timestep` is set and is not divisible by
+ `new_value`.
+ """
+ self._check_root_entity('physics_timestep')
+ if hasattr(self, '_control_timestep'):
+ _check_timesteps_divisible(self._control_timestep, new_value)
+ self.root_entity.mjcf_model.option.timestep = new_value
+
+ def set_timesteps(self, control_timestep, physics_timestep):
+ """Changes the agent's control timestep and physics simulation timestep.
+
+ This is equivalent to modifying `control_timestep` and `physics_timestep`
+ simultaneously. The divisibility check is performed between the two
+ new values.
+
+ Args:
+ control_timestep: the agent's new control timestep (in seconds).
+ physics_timestep: the new physics simulation timestep (in seconds).
+
+ Raises:
+ ValueError: if `control_timestep` is not divisible by `physics_timestep`.
+ """
+ self._check_root_entity('set_timesteps')
+ _check_timesteps_divisible(control_timestep, physics_timestep)
+ self.root_entity.mjcf_model.option.timestep = physics_timestep
+ self._control_timestep = control_timestep
+
+ @property
+ def physics_steps_per_control_step(self):
+ """Returns number of physics steps per agent's control step."""
+ return _check_timesteps_divisible(
+ self.control_timestep, self.physics_timestep)
+
+ def action_spec(self, physics):
+ """Returns a `BoundedArray` spec matching the `Physics` actuators.
+
+ BoundedArray.name should contain a tab-separated list of actuator names.
+ When overriding this method, non-MuJoCo actuators should be added to the
+ top of the list when possible, as a matter of convention.
+
+ Args:
+ physics: used to query actuator names in the model.
+ """
+ names = [physics.model.id2name(i, 'actuator') or str(i)
+ for i in range(physics.model.nu)]
+ action_spec = mujoco.action_spec(physics)
+ return specs.BoundedArray(shape=action_spec.shape,
+ dtype=action_spec.dtype,
+ minimum=action_spec.minimum,
+ maximum=action_spec.maximum,
+ name='\t'.join(names))
+
+ def get_reward_spec(self):
+ """Optional method to define non-scalar rewards for a `Task`."""
+ return None
+
+ def get_discount_spec(self):
+ """Optional method to define non-scalar discounts for a `Task`."""
+ return None
+
+ def initialize_episode_mjcf(self, random_state):
+ """Modifies the MJCF model of this task before the next episode begins.
+
+ The Environment calls this method and recompiles the physics
+ if necessary before calling `initialize_episode`.
+
+ Args:
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def initialize_episode(self, physics, random_state):
+ """Modifies the physics state before the next episode begins.
+
+ The Environment calls this method after `initialize_episode_mjcf`, and also
+ after the physics has been recompiled if necessary.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def before_step(self, physics, action, random_state):
+ """A callback which is executed before an agent control step.
+
+ The default implementation sets the control signal for the actuators in
+ `physics` to be equal to `action`. Subclasses that override this method
+ should ensure that the overriding method also sets the control signal before
+ returning, either by calling `super().before_step`, or by setting
+ the control signal explicitly (e.g. in order to create a non-trivial mapping
+ between `action` and the control signal).
+
+ Args:
+ physics: An instance of `control.Physics`.
+ action: A NumPy array corresponding to agent actions.
+ random_state: An instance of `np.random.RandomState` (unused).
+ """
+ del random_state # Unused.
+ physics.set_control(action)
+
+ def before_substep(self, physics, action, random_state):
+ """A callback which is executed before a simulation step.
+
+ Actuation can be set, or overridden, in this callback.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ action: A NumPy array corresponding to agent actions.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def after_substep(self, physics, random_state):
+ """A callback which is executed after a simulation step.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def after_step(self, physics, random_state):
+ """A callback which is executed after an agent control step.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_reward(self, physics):
+ """Calculates the reward signal given the physics state.
+
+ Args:
+ physics: A Physics object.
+
+ Returns:
+ A float
+ """
+ raise NotImplementedError
+
+ def should_terminate_episode(self, physics): # pylint: disable=unused-argument
+ """Determines whether the episode should terminate given the physics state.
+
+ Args:
+ physics: A Physics object
+
+ Returns:
+ A boolean
+ """
+ return False
+
+ def get_discount(self, physics): # pylint: disable=unused-argument
+ """Calculates the reward discount factor given the physics state.
+
+ Args:
+ physics: A Physics object
+
+ Returns:
+ A float
+ """
+ return 1.0
+
+
+class NullTask(Task):
+ """A class that wraps a single `Entity` into a `Task` with no reward."""
+
+ def __init__(self, root_entity):
+ self._root_entity = root_entity
+
+ @property
+ def root_entity(self):
+ return self._root_entity
+
+ def get_reward(self, physics):
+ return 0.0
diff --git a/dm_control/composer/variation/__init__.py b/dm_control/composer/variation/__init__.py
new file mode 100644
index 00000000..60e0f049
--- /dev/null
+++ b/dm_control/composer/variation/__init__.py
@@ -0,0 +1,136 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A module that helps manage model variation in Composer environments."""
+
+import collections
+import copy
+
+from dm_control.composer.variation.base import Variation
+from dm_control.composer.variation.variation_values import evaluate
+
+
+class _VariationInfo:
+
+ __slots__ = ['initial_value', 'variation']
+
+ def __init__(self, initial_value=None, variation=None):
+ self.initial_value = initial_value
+ self.variation = variation
+
+
+class MJCFVariator:
+ """Helper object for applying variations to MJCF attributes.
+
+ An instance of this class remembers the original value of each MJCF attribute
+ the first time a variation is applied. The original value is then passed as an
+ argument to each variation callable.
+ """
+
+ def __init__(self):
+ self._variations = collections.defaultdict(dict)
+
+ def bind_attributes(self, element, **kwargs):
+ """Binds variations to attributes of an MJCF element.
+
+ Args:
+ element: An `mjcf.Element` object.
+ **kwargs: Keyword arguments mapping attribute names to the corresponding
+ variations. A variation is either a fixed value or a callable that
+ optionally takes the original value of an attribute and returns a
+ new value.
+ """
+ for attribute_name, variation in kwargs.items():
+ if variation is None and attribute_name in self._variations[element]:
+ del self._variations[element][attribute_name]
+ else:
+ initial_value = copy.copy(getattr(element, attribute_name))
+ self._variations[element][attribute_name] = (
+ _VariationInfo(initial_value, variation))
+
+ def apply_variations(self, random_state):
+ """Applies variations in-place to all bound MJCF elements.
+
+ Args:
+ random_state: A `numpy.random.RandomState` instance.
+ """
+ for element, attribute_variations in self._variations.items():
+ new_values = {}
+ for attribute_name, variation_info in attribute_variations.items():
+ current_value = getattr(element, attribute_name)
+ if variation_info.initial_value is None:
+ variation_info.initial_value = copy.copy(current_value)
+ new_values[attribute_name] = evaluate(
+ variation_info.variation, variation_info.initial_value,
+ current_value, random_state)
+ element.set_attributes(**new_values)
+
+ def clear(self):
+ """Clears all bound attribute variations."""
+ self._variations.clear()
+
+ def reset_initial_values(self):
+ for variations in self._variations.values():
+ for variation_info in variations.values():
+ variation_info.initial_value = None
+
+
+class PhysicsVariator:
+ """Helper object for applying variations to MjModel and MjData.
+
+ An instance of this class remembers the original value of each attribute
+ the first time a variation is applied. The original value is then passed as an
+ argument to each variation callable.
+ """
+
+ def __init__(self):
+ self._variations = collections.defaultdict(dict)
+
+ def bind_attributes(self, element, **kwargs):
+ """Binds variations to physics-bound attributes of an MJCF element.
+
+ Args:
+ element: An `mjcf.Element` object.
+ **kwargs: Keyword arguments mapping attribute names to the corresponding
+ variations. A variation is either a fixed value or a callable that
+ optionally takes the original value of an attribute and returns a
+ new value.
+ """
+ for attribute_name, variation in kwargs.items():
+ if variation is None and attribute_name in self._variations[element]:
+ del self._variations[element][attribute_name]
+ else:
+ self._variations[element][attribute_name] = (
+ _VariationInfo(None, variation))
+
+ def apply_variations(self, physics, random_state):
+ for element, variations in self._variations.items():
+ binding = physics.bind(element)
+ for attribute_name, variation_info in variations.items():
+ current_value = getattr(binding, attribute_name)
+ if variation_info.initial_value is None:
+ variation_info.initial_value = copy.copy(current_value)
+ setattr(binding, attribute_name, evaluate(
+ variation_info.variation, variation_info.initial_value,
+ current_value, random_state))
+
+ def clear(self):
+ """Clears all bound attribute variations."""
+ self._variations.clear()
+
+ def reset_initial_values(self):
+ for variations in self._variations.values():
+ for variation_info in variations.values():
+ variation_info.initial_value = None
diff --git a/dm_control/composer/variation/base.py b/dm_control/composer/variation/base.py
new file mode 100644
index 00000000..ea3d1fa8
--- /dev/null
+++ b/dm_control/composer/variation/base.py
@@ -0,0 +1,172 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for variations and binary operations on variations."""
+
+import abc
+import operator
+
+from dm_control.composer.variation import variation_values
+import numpy as np
+
+
+class Variation(metaclass=abc.ABCMeta):
+ """Abstract base class for variations."""
+
+ @abc.abstractmethod
+ def __call__(self, initial_value, current_value, random_state):
+ """Generates a value for this variation.
+
+ Args:
+ initial_value: The original value of the attribute being varied.
+ Absolute variations may ignore this argument.
+ current_value: The current value of the attribute being varied.
+ Absolute variations may ignore this argument.
+ random_state: A `numpy.random.RandomState` used to generate the value.
+ Deterministic variations may ignore this argument.
+
+ Returns:
+ The next value for this variation.
+ """
+
+ def __add__(self, other):
+ return _BinaryOperation(operator.add, self, other)
+
+ def __radd__(self, other):
+ return _BinaryOperation(operator.add, other, self)
+
+ def __sub__(self, other):
+ return _BinaryOperation(operator.sub, self, other)
+
+ def __rsub__(self, other):
+ return _BinaryOperation(operator.sub, other, self)
+
+ def __mul__(self, other):
+ return _BinaryOperation(operator.mul, self, other)
+
+ def __rmul__(self, other):
+ return _BinaryOperation(operator.mul, other, self)
+
+ def __truediv__(self, other):
+ return _BinaryOperation(operator.truediv, self, other)
+
+ def __rtruediv__(self, other):
+ return _BinaryOperation(operator.truediv, other, self)
+
+ def __floordiv__(self, other):
+ return _BinaryOperation(operator.floordiv, self, other)
+
+ def __rfloordiv__(self, other):
+ return _BinaryOperation(operator.floordiv, other, self)
+
+ def __pow__(self, other):
+ return _BinaryOperation(operator.pow, self, other)
+
+ def __rpow__(self, other):
+ return _BinaryOperation(operator.pow, other, self)
+
+ def __getitem__(self, index):
+ return _GetItemOperation(self, index)
+
+ def __neg__(self):
+ return _UnaryOperation(operator.neg, self)
+
+
+class _UnaryOperation(Variation):
+ """Represents the result of applying a unary operator to a Variation."""
+
+ def __init__(self, op, variation):
+ self._op = op
+ self._variation = variation
+
+ def __eq__(self, other):
+ if not isinstance(other, _UnaryOperation):
+ return False
+ return self._op == other._op and self._variation == other._variation
+
+ def __str__(self):
+ return f"{self._op.__name__}({self._variation})"
+
+ def __repr__(self):
+ return f"UnaryOperation({self._op.__name__}({self._variation}))"
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ value = variation_values.evaluate(
+ self._variation, initial_value, current_value, random_state
+ )
+ return self._op(value)
+
+
+class _BinaryOperation(Variation):
+ """Represents the result of applying a binary operator to two Variations."""
+
+ def __init__(self, op, first, second):
+ self._first = first
+ self._second = second
+ self._op = op
+
+ def __eq__(self, other):
+ if not isinstance(other, _BinaryOperation):
+ return False
+ return (
+ self._op == other._op
+ and self._first == other._first
+ and self._second == other._second
+ )
+
+ def __str__(self):
+ return f"{self._op.__name__}({self._first}, {self._second})"
+
+ def __repr__(self):
+ return (
+ f"BinaryOperation({self._op.__name__}({self._first!r},"
+ f" {self._second!r}))"
+ )
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ first_value = variation_values.evaluate(
+ self._first, initial_value, current_value, random_state
+ )
+ second_value = variation_values.evaluate(
+ self._second, initial_value, current_value, random_state
+ )
+ return self._op(first_value, second_value)
+
+
+class _GetItemOperation(Variation):
+ """Returns a single element from the output of a Variation."""
+
+ def __init__(self, variation, index):
+ self._variation = variation
+ self._index = index
+
+ def __eq__(self, other):
+ if not isinstance(other, _GetItemOperation):
+ return False
+ return self._variation == other._variation and self._index == other._index
+
+ def __str__(self):
+ return f"{self._variation}[{self._index}]"
+
+ def __repr__(self):
+ return (
+ f"GetItemOperation({self._variation!r}[{self._index}])"
+ )
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ value = variation_values.evaluate(
+ self._variation, initial_value, current_value, random_state
+ )
+ return np.asarray(value)[self._index]
diff --git a/dm_control/composer/variation/colors.py b/dm_control/composer/variation/colors.py
new file mode 100644
index 00000000..2c6bbe80
--- /dev/null
+++ b/dm_control/composer/variation/colors.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Variations in colors.
+
+Classes in this module allow users to specify a variation for each channel in
+a variety of color spaces. The generated values are always RGBA arrays.
+"""
+
+import colorsys
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+import numpy as np
+
+
+class RgbVariation(base.Variation):
+ """Represents a variation in the RGB color space.
+
+ This class allows users to specify independent variations in the R, G, B, and
+ alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, r, g, b, alpha=1.0):
+ self._r, self._g, self._b = r, g, b
+ self._alpha = alpha
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ return np.asarray(
+ variation_values.evaluate([self._r, self._g, self._b, self._alpha],
+ initial_value, current_value, random_state))
+
+ def __eq__(self, other):
+ if not isinstance(other, RgbVariation):
+ return False
+ return (
+ self._r == other._r
+ and self._g == other._g
+ and self._b == other._b
+ and self._alpha == other._alpha
+ )
+
+
+class HsvVariation(base.Variation):
+ """Represents a variation in the HSV color space.
+
+ This class allows users to specify independent variations in the H, S, V, and
+ alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, h, s, v, alpha=1.0):
+ self._h, self._s, self._v = h, s, v
+ self._alpha = alpha
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ h, s, v, alpha = variation_values.evaluate(
+ (self._h, self._s, self._v, self._alpha), initial_value, current_value,
+ random_state)
+ return np.asarray(list(colorsys.hsv_to_rgb(h, s, v)) + [alpha])
+
+ def __eq__(self, other):
+ if not isinstance(other, HsvVariation):
+ return False
+ return (
+ self._h == other._h
+ and self._s == other._s
+ and self._v == other._v
+ and self._alpha == other._alpha
+ )
+
+ def __repr__(self):
+ return (
+ f"HsvVariation(h={self._h}, s={self._s}, v={self._v}, "
+ f"alpha={self._alpha})"
+ )
+
+
+class GrayVariation(HsvVariation):
+ """Represents a variation in gray level.
+
+ This class allows users to specify independent variations in the gray level
+ and alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, gray_level, alpha=1.0):
+ super().__init__(h=0.0, s=0.0, v=gray_level, alpha=alpha)
+
+ def __repr__(self):
+ return (
+ f"GrayVariation(gray_level={self._v}, alpha={self._alpha})"
+ )
diff --git a/dm_control/composer/variation/deterministic.py b/dm_control/composer/variation/deterministic.py
new file mode 100644
index 00000000..a6ad4cba
--- /dev/null
+++ b/dm_control/composer/variation/deterministic.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Deterministic variations."""
+
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation.variation_values import evaluate
+
+
+class Constant(base.Variation):
+ """Wraps a constant value into a Variation object.
+
+ This class is provided mainly for use in tests, to check that variations are
+ invoked correctly without having to introduce randomness in test cases.
+ """
+
+ def __init__(self, value):
+ self._value = value
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ return self._value
+
+ def __eq__(self, other):
+ if not isinstance(other, Constant):
+ return False
+ return self._value == other._value
+
+ def __str__(self):
+ return f"{self._value}"
+
+ def __repr__(self):
+ return f"Constant({self._value!r})"
+
+
+class Sequence(base.Variation):
+ """Variation representing a fixed sequence of values."""
+
+ def __init__(self, values):
+ self._values = values
+ self._iterator = iter(self._values)
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ try:
+ return evaluate(next(self._iterator), initial_value=initial_value,
+ current_value=current_value, random_state=random_state)
+ except StopIteration:
+ self._iterator = iter(self._values)
+ return evaluate(next(self._iterator), initial_value=initial_value,
+ current_value=current_value, random_state=random_state)
+
+
+class Identity(base.Variation):
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ return current_value
+
+ def __eq__(self, other):
+ return isinstance(other, Identity)
diff --git a/dm_control/composer/variation/distributions.py b/dm_control/composer/variation/distributions.py
new file mode 100644
index 00000000..f47e88fc
--- /dev/null
+++ b/dm_control/composer/variation/distributions.py
@@ -0,0 +1,258 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Standard statistical distributions that conform to the Variation API."""
+
+import abc
+import functools
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation.variation_values import evaluate
+import numpy as np
+
+
+class Distribution(base.Variation, metaclass=abc.ABCMeta):
+ """Base Distribution class for sampling a parametrized distribution.
+
+ Subclasses need to implement `_callable`, which needs to return a callable
+ based on the random_state passed as arg. This callable then gets called using
+ the arguments passed to the constructor, after being evaluated. This allows
+ the distribution parameters themselves to be instances of `base.Variation`.
+ By default samples are drawn in the shape of `initial_value`, unless the
+ optional `single_sample` constructor arg is set to `True`, in which case only
+ a single sample is drawn.
+ """
+ __slots__ = ('_single_sample', '_args', '_kwargs')
+
+ def __init__(self, *args, **kwargs):
+ self._single_sample = kwargs.pop('single_sample', False)
+ self._args = args
+ self._kwargs = kwargs
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ local_random_state = random_state or np.random
+ size = (
+ None if self._single_sample or initial_value is None # pylint: disable=g-long-ternary
+ else np.shape(initial_value))
+ local_args = evaluate(
+ self._args,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ local_kwargs = evaluate(
+ self._kwargs,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ return self._callable(local_random_state)(
+ *local_args, size=size, **local_kwargs)
+
+ def __getattr__(self, name):
+ if name.startswith('__'):
+ raise AttributeError # Stops infinite recursion during deepcopy.
+ elif name in self._kwargs:
+ return self._kwargs[name]
+ else:
+ raise AttributeError('{!r} object has no attribute {!r}'.format(
+ type(self).__name__, name))
+
+ @abc.abstractmethod
+ def _callable(self, random_state):
+ raise NotImplementedError
+
+ def __eq__(self, other):
+ if not isinstance(other, type(self)):
+ return False
+ return (
+ self._args == other._args
+ and self._kwargs == other._kwargs
+ and self._single_sample == other._single_sample
+ )
+
+ def __repr__(self):
+ return '{}(args={}, kwargs={}, single_sample={})'.format(
+ type(self).__name__,
+ self._args,
+ self._kwargs,
+ self._single_sample)
+
+
+class Uniform(Distribution):
+ __slots__ = ()
+
+ def __init__(self, low=0.0, high=1.0, single_sample=False):
+ super().__init__(low=low, high=high, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.uniform
+
+
+class UniformInteger(Distribution):
+ __slots__ = ()
+
+ def __init__(self, low, high=None, single_sample=False):
+ super().__init__(low, high=high, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.randint
+
+
+class UniformChoice(Distribution):
+ __slots__ = ()
+
+ def __init__(self, choices, single_sample=False):
+ super().__init__(choices, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.choice
+
+
+class UniformPointOnSphere(base.Variation):
+ """Samples a point on the unit sphere, i.e. a 3D vector with norm 1."""
+ __slots__ = ()
+
+ def __init__(self, single_sample=False):
+ self._single_sample = single_sample
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ size = (
+ 3 if self._single_sample or initial_value is None # pylint: disable=g-long-ternary
+ else np.append(np.shape(initial_value), 3))
+ axis = random_state.normal(size=size)
+ axis /= np.linalg.norm(axis, axis=-1, keepdims=True)
+ return axis
+
+ def __eq__(self, other):
+ if not isinstance(other, UniformPointOnSphere):
+ return False
+ return self._single_sample == other._single_sample
+
+ def __repr__(self):
+ return '{}(single_sample={})'.format(
+ type(self).__name__,
+ self._single_sample,
+ )
+
+
+class Normal(Distribution):
+ __slots__ = ()
+
+ def __init__(self, loc=0.0, scale=1.0, single_sample=False):
+ super().__init__(loc=loc, scale=scale, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.normal
+
+
+class LogNormal(Distribution):
+ __slots__ = ()
+
+ def __init__(self, mean=0.0, sigma=1.0, single_sample=False):
+ super().__init__(mean=mean, sigma=sigma, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.lognormal
+
+
+class Exponential(Distribution):
+ __slots__ = ()
+
+ def __init__(self, scale=1.0, single_sample=False):
+ super().__init__(scale=scale, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.exponential
+
+
+class Poisson(Distribution):
+ __slots__ = ()
+
+ def __init__(self, lam=1.0, single_sample=False):
+ super().__init__(lam=lam, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.poisson
+
+
+class Bernoulli(Distribution):
+ __slots__ = ()
+
+ def __init__(self, prob=0.5, single_sample=False):
+ super().__init__(prob, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return functools.partial(random_state.binomial, 1)
+
+
+_NEGATIVE_STDEV = '`stdev` must be >= 0, got {}.'
+_NEGATIVE_TIMESCALE = '`timescale` must be >= 0, got {}.'
+
+
+class BiasedRandomWalk(base.Variation):
+ """A class for generating noise from a zero-mean Ornstein-Uhlenbeck process.
+
+ Let
+ `retain = np.exp(-1. / timescale)`
+ and
+ `scale = stdev * sqrt(1 - (retain * retain))`
+ Then the discrete-time first-order filtered diffusion process
+ `x_next = retain * x + N(0, scale)`
+ has standard deviation `stdev` and characteristic timescale `timescale`.
+ """
+ __slots__ = ('_retain', '_scale', '_value')
+
+ def __init__(self, stdev=0.1, timescale=10.):
+ """Initializes a `BiasedRandomWalk`.
+
+ Args:
+ stdev: Float. Standard deviation of the output sequence.
+ timescale: Float. Number of timesteps characteristic of the random walk.
+ After `timescale` steps the correlation decays by a factor of exp(-1).
+ Must be >= 0; a value of 0 yields uncorrelated normal noise.
+
+ Raises:
+ ValueError: if either `stdev` or `timescale` is negative.
+ """
+ if stdev < 0:
+ raise ValueError(_NEGATIVE_STDEV.format(stdev))
+ if timescale < 0:
+ raise ValueError(_NEGATIVE_TIMESCALE.format(timescale))
+ elif timescale == 0:
+ self._retain = 0.
+ else:
+ self._retain = np.exp(-1. / timescale)
+ self._scale = stdev * np.sqrt(1 - (self._retain * self._retain))
+ self._value = 0.0
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ self._value = (
+ self._retain * self._value +
+ random_state.normal(loc=0.0, scale=self._scale))
+ return self._value
+
+ def __eq__(self, other):
+ # __eq__ shouldn't be used for this one, because it's stateful.
+ return id(self) == id(other)
+
+ def __repr__(self):
+ # include id(self), to make sure that two instances with the same parameters
+ # don't appear equal in logs.
+ return '{}(id={}, scale={}, retain={}, value={})'.format(
+ type(self).__name__,
+ id(self),
+ self._scale,
+ self._retain,
+ self._value)
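
The `BiasedRandomWalk` docstring states that the recursion `x_next = retain * x + N(0, scale)` has standard deviation `stdev`. This follows because the stationary variance `V` satisfies `V = retain**2 * V + scale**2`, so `V = scale**2 / (1 - retain**2) = stdev**2`. The following is a minimal standalone sketch of that recursion using only the Python standard library, not the `dm_control` class itself; the helper names `biased_random_walk` and `sample_std` are hypothetical.

```python
import math
import random


def biased_random_walk(stdev, timescale, num_steps, seed=0):
    """Generates a sequence with the recursion from the docstring (sketch)."""
    rng = random.Random(seed)
    # A timescale of 0 means no memory, i.e. uncorrelated normal noise.
    retain = 0.0 if timescale == 0 else math.exp(-1.0 / timescale)
    scale = stdev * math.sqrt(1.0 - retain * retain)
    value = 0.0
    sequence = []
    for _ in range(num_steps):
        value = retain * value + rng.gauss(0.0, scale)
        sequence.append(value)
    return sequence


def sample_std(values):
    """Population standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


sequence = biased_random_walk(stdev=1.0, timescale=10, num_steps=200_000)
print(sample_std(sequence))  # expected to be close to the requested stdev
```

A long sequence is needed because consecutive samples are correlated, which is also why the test in `distributions_test.py` scales its iteration count with `timescale`.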
diff --git a/dm_control/composer/variation/distributions_test.py b/dm_control/composer/variation/distributions_test.py
new file mode 100644
index 00000000..459e20f6
--- /dev/null
+++ b/dm_control/composer/variation/distributions_test.py
@@ -0,0 +1,140 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.variation import distributions
+import numpy as np
+
+RANDOM_SEED = 123
+NUM_ITERATIONS = 100
+
+
+def _make_random_state():
+ return np.random.RandomState(RANDOM_SEED)
+
+
+class DistributionsTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._variation_random_state = _make_random_state()
+ self._np_random_state = _make_random_state()
+
+ def testUniform(self):
+ lower, upper = [2, 3, 4], [5, 6, 7]
+ variation = distributions.Uniform(low=lower, high=upper)
+ for _ in range(NUM_ITERATIONS):
+ np.testing.assert_array_equal(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.uniform(lower, upper))
+
+ self.assertEqual(variation, distributions.Uniform(low=lower, high=upper))
+ self.assertNotEqual(variation, distributions.Uniform(low=upper, high=upper))
+ self.assertIn('[2, 3, 4]', repr(variation))
+
+ def testUniformChoice(self):
+ choices = ['apple', 'banana', 'cherry']
+ variation = distributions.UniformChoice(choices)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.choice(choices))
+
+ self.assertIn('banana', repr(variation))
+
+ def testUniformPointOnSphere(self):
+ variation = distributions.UniformPointOnSphere()
+ samples = []
+ for _ in range(NUM_ITERATIONS):
+ sample = variation(random_state=self._variation_random_state)
+ self.assertEqual(sample.size, 3)
+ np.testing.assert_approx_equal(np.linalg.norm(sample), 1.0)
+ samples.append(sample)
+ # Make sure that none of the samples are the same.
+ self.assertLen(set(np.reshape(np.asarray(samples), -1)), 3 * NUM_ITERATIONS)
+ self.assertEqual(variation, distributions.UniformPointOnSphere())
+ self.assertNotEqual(
+ variation, distributions.UniformPointOnSphere(single_sample=True)
+ )
+
+ def testNormal(self):
+ loc, scale = 1, 2
+ variation = distributions.Normal(loc=loc, scale=scale)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.normal(loc, scale))
+ self.assertEqual(variation, distributions.Normal(loc=loc, scale=scale))
+ self.assertNotEqual(
+ variation, distributions.Normal(loc=loc*2, scale=scale)
+ )
+ self.assertEqual(
+ "Normal(args=(), kwargs={'loc': 1, 'scale': 2}, single_sample=False)",
+ repr(variation),
+ )
+
+ def testExponential(self):
+ scale = 3
+ variation = distributions.Exponential(scale=scale)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.exponential(scale))
+ self.assertEqual(variation, distributions.Exponential(scale=scale))
+ self.assertNotEqual(
+ variation, distributions.Exponential(scale=scale*2)
+ )
+ self.assertEqual(
+ "Exponential(args=(), kwargs={'scale': 3}, single_sample=False)",
+ repr(variation),
+ )
+
+ def testPoisson(self):
+ lam = 4
+ variation = distributions.Poisson(lam=lam)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.poisson(lam))
+ self.assertEqual(variation, distributions.Poisson(lam=lam))
+ self.assertNotEqual(
+ variation, distributions.Poisson(lam=lam*2)
+ )
+ self.assertEqual(
+ "Poisson(args=(), kwargs={'lam': 4}, single_sample=False)",
+ repr(variation),
+ )
+
+ @parameterized.parameters(0, 10)
+ def testBiasedRandomWalk(self, timescale):
+ stdev = 1.
+ variation = distributions.BiasedRandomWalk(stdev=stdev, timescale=timescale)
+ sequence = [variation(random_state=self._variation_random_state)
+ for _ in range(int(max(timescale, 1)*NUM_ITERATIONS*1000))]
+ self.assertAlmostEqual(np.mean(sequence), 0., delta=0.01)
+ self.assertAlmostEqual(np.std(sequence), stdev, delta=0.01)
+
+ @parameterized.parameters(
+ dict(arg_name='stdev', template=distributions._NEGATIVE_STDEV),
+ dict(arg_name='timescale', template=distributions._NEGATIVE_TIMESCALE))
+ def testBiasedRandomWalkExceptions(self, arg_name, template):
+ bad_value = -1.
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, template.format(bad_value)):
+ _ = distributions.BiasedRandomWalk(**{arg_name: bad_value})
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/variation/math.py b/dm_control/composer/variation/math.py
new file mode 100644
index 00000000..2b79ffeb
--- /dev/null
+++ b/dm_control/composer/variation/math.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Math operations on variation objects."""
+
+import abc
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation.variation_values import evaluate
+
+import numpy as np
+
+
+class MathOp(base.Variation):
+ """Base MathOp class for applying math operations on variation objects.
+
+ Subclasses need to implement `_callable`, which takes in a single value and
+ applies the desired math operation. This operation gets applied to the result
+ of the evaluated base variation object passed at construction. Structured
+ variation objects are automatically traversed.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._args = args
+ self._kwargs = kwargs
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ local_args = evaluate(
+ self._args,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ local_kwargs = evaluate(
+ self._kwargs,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ return self._callable(*local_args, **local_kwargs)
+
+ @property
+ @abc.abstractmethod
+ def _callable(self):
+ pass
+
+ def __eq__(self, other):
+ if not isinstance(other, type(self)):
+ return False
+ return (
+ self._args == other._args
+ and self._kwargs == other._kwargs
+ )
+
+ def __repr__(self):
+ return '{}(args={}, kwargs={})'.format(
+ type(self).__name__,
+ self._args,
+ self._kwargs,
+ )
+
+
+class Log(MathOp):
+
+ @property
+ def _callable(self):
+ return np.log
+
+
+class Max(MathOp):
+
+ @property
+ def _callable(self):
+ return np.max
+
+
+class Min(MathOp):
+
+ @property
+ def _callable(self):
+ return np.min
+
+
+class Norm(MathOp):
+
+ @property
+ def _callable(self):
+ return np.linalg.norm
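
As a rough illustration of the `MathOp` pattern above (evaluate the stored arguments first, then apply `_callable` to the results), here is a simplified standalone sketch. It is an assumption-laden stand-in, not the class from this diff: `MathOpSketch`, `LogSketch` and `current_value_squared` are hypothetical names, only positional arguments are handled, and nested structures are not traversed.

```python
import math


class MathOpSketch:
    """Evaluates callable arguments, then applies `_callable` to the results."""

    def __init__(self, *args):
        self._args = args

    def __call__(self, initial_value=None, current_value=None, random_state=None):
        context = dict(
            initial_value=initial_value,
            current_value=current_value,
            random_state=random_state,
        )
        # Callables are treated as variations; anything else is a constant.
        local_args = tuple(
            arg(**context) if callable(arg) else arg for arg in self._args
        )
        return self._callable(*local_args)


class LogSketch(MathOpSketch):

    @property
    def _callable(self):
        return math.log


def current_value_squared(initial_value=None, current_value=None, random_state=None):
    """A hypothetical variation that squares the current value."""
    return current_value ** 2


log_of_square = LogSketch(current_value_squared)
result = log_of_square(current_value=math.e)  # log(e ** 2)
```

Exposing `_callable` as a property keeps each subclass down to a single line naming the function to apply.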
diff --git a/dm_control/composer/variation/noises.py b/dm_control/composer/variation/noises.py
new file mode 100644
index 00000000..d84f187c
--- /dev/null
+++ b/dm_control/composer/variation/noises.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Meta-variations that modify original values by a specified variation."""
+
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+
+
+class Additive(base.Variation):
+ """A variation that adds to an existing value.
+
+ This variation takes a value generated by another variation and adds it to an
+ existing value. In cumulative mode, the generated value is added to the
+ current value being varied. In non-cumulative mode, the generated value is
+ added to a fixed initial value.
+ """
+
+ def __init__(self, variation, cumulative=False):
+ self._variation = variation
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ base_value = current_value if self._cumulative else initial_value
+ return base_value + (
+ variation_values.evaluate(self._variation, initial_value, current_value,
+ random_state))
+
+ def __eq__(self, other):
+ if not isinstance(other, Additive):
+ return False
+ return (
+ self._variation == other._variation
+ and self._cumulative == other._cumulative
+ )
+
+ def __repr__(self):
+ return (
+ f"Additive(variation={self._variation}, cumulative={self._cumulative})"
+ )
+
+
+class Multiplicative(base.Variation):
+ """A variation that multiplies an existing value.
+
+ This variation takes a value generated by another variation and multiplies
+ an existing value by it. In cumulative mode, the current value being varied
+ is multiplied by the generated value. In non-cumulative mode, a fixed initial
+ value is multiplied by the generated value.
+ """
+
+ def __init__(self, variation, cumulative=False):
+ self._variation = variation
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ base_value = current_value if self._cumulative else initial_value
+ return base_value * (
+ variation_values.evaluate(self._variation, initial_value, current_value,
+ random_state))
+
+ def __eq__(self, other):
+ if not isinstance(other, Multiplicative):
+ return False
+ return (
+ self._variation == other._variation
+ and self._cumulative == other._cumulative
+ )
+
+ def __repr__(self):
+ return (
+ f"Multiplicative(variation={self._variation}, "
+ f"cumulative={self._cumulative})"
+ )
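
The difference between cumulative and non-cumulative mode is easiest to see when the same variation is applied repeatedly. The sketch below mirrors only the `base_value` selection in `Additive.__call__` with a constant amount; `apply_additive` is a hypothetical helper, not part of the module.

```python
def apply_additive(amount, initial_value, current_value, cumulative):
    """Mirrors the base-value selection used by `Additive.__call__` (sketch)."""
    base_value = current_value if cumulative else initial_value
    return base_value + amount


initial = 1
current_fixed = initial
current_cumulative = initial
for _ in range(3):
    # Non-cumulative: always offsets the fixed initial value.
    current_fixed = apply_additive(3, initial, current_fixed, cumulative=False)
    # Cumulative: offsets whatever the previous application produced.
    current_cumulative = apply_additive(
        3, initial, current_cumulative, cumulative=True)

print(current_fixed)       # 4, i.e. initial + amount
print(current_cumulative)  # 10, i.e. initial + 3 * amount
```

These values match the expectations encoded in `testAdditive` and `testAdditiveCumulative`, which assert `initial_value + amount` and `initial_value + amount * (i + 1)` respectively.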
diff --git a/dm_control/composer/variation/noises_test.py b/dm_control/composer/variation/noises_test.py
new file mode 100644
index 00000000..458b79f2
--- /dev/null
+++ b/dm_control/composer/variation/noises_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for noises."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.variation import deterministic
+from dm_control.composer.variation import noises
+
+NUM_ITERATIONS = 100
+
+
+class NoisesTest(parameterized.TestCase):
+
+ @parameterized.parameters(False, True)
+ def testAdditive(self, use_constant_variation_object):
+ amount = 2
+ if use_constant_variation_object:
+ variation = noises.Additive(deterministic.Constant(amount))
+ else:
+ variation = noises.Additive(amount)
+ initial_value = 0
+ current_value = initial_value
+ for _ in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value + amount)
+
+ @parameterized.parameters(False, True)
+ def testAdditiveCumulative(self, use_constant_variation_object):
+ amount = 3
+ if use_constant_variation_object:
+ variation = noises.Additive(
+ deterministic.Constant(amount), cumulative=True)
+ else:
+ variation = noises.Additive(amount, cumulative=True)
+ initial_value = 1
+ current_value = initial_value
+ for i in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value + amount * (i + 1))
+
+ @parameterized.parameters(False, True)
+ def testMultiplicative(self, use_constant_variation_object):
+ amount = 23
+ if use_constant_variation_object:
+ variation = noises.Multiplicative(deterministic.Constant(amount))
+ else:
+ variation = noises.Multiplicative(amount)
+ initial_value = 3
+ current_value = initial_value
+ for _ in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value * amount)
+
+ @parameterized.parameters(False, True)
+ def testMultiplicativeCumulative(self, use_constant_variation_object):
+ amount = 2
+ if use_constant_variation_object:
+ variation = noises.Multiplicative(
+ deterministic.Constant(amount), cumulative=True)
+ else:
+ variation = noises.Multiplicative(amount, cumulative=True)
+ initial_value = 3
+ current_value = initial_value
+ for i in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value * amount ** (i + 1))
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/variation/rotations.py b/dm_control/composer/variation/rotations.py
new file mode 100644
index 00000000..d242eb54
--- /dev/null
+++ b/dm_control/composer/variation/rotations.py
@@ -0,0 +1,146 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Variations in 3D rotations."""
+
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+from dm_control.utils import transformations
+import numpy as np
+
+IDENTITY_QUATERNION = np.array([1., 0., 0., 0.])
+
+
+class UniformQuaternion(base.Variation):
+ """Uniformly distributed unit quaternions."""
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ u1, u2, u3 = random_state.uniform([0.] * 3, [1., 2. * np.pi, 2. * np.pi])
+ return np.array([np.sqrt(1. - u1) * np.sin(u2),
+ np.sqrt(1. - u1) * np.cos(u2),
+ np.sqrt(u1) * np.sin(u3),
+ np.sqrt(u1) * np.cos(u3)])
+
+ def __eq__(self, other):
+ return isinstance(other, UniformQuaternion)
+
+ def __repr__(self):
+ return "UniformQuaternion()"
+
+
+class QuaternionFromAxisAngle(base.Variation):
+ """Quaternion variation specified in terms of variations in axis and angle."""
+
+ def __init__(self, axis, angle):
+ self._axis = axis
+ self._angle = angle
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ axis = variation_values.evaluate(
+ self._axis, initial_value, current_value, random_state)
+ angle = variation_values.evaluate(
+ self._angle, initial_value, current_value, random_state)
+ return transformations.axisangle_to_quat(np.asarray(axis) * angle)
+
+ def __eq__(self, other):
+ if not isinstance(other, QuaternionFromAxisAngle):
+ return False
+ return (
+ self._axis == other._axis
+ and self._angle == other._angle
+ )
+
+ def __repr__(self):
+ return (
+ f"QuaternionFromAxisAngle(axis={self._axis}, angle={self._angle})"
+ )
+
+
+class QuaternionPreMultiply(base.Variation):
+ """A variation that pre-multiplies an existing quaternion value.
+
+ This variation takes a quaternion value generated by another variation and
+ pre-multiplies it to an existing value. In cumulative mode, the new quaternion
+ is pre-multiplied to the current value being varied. In non-cumulative mode,
+ the new quaternion is pre-multiplied to a fixed initial value.
+ """
+
+ def __init__(self, quat, cumulative=False):
+ self._quat = quat
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ q1 = variation_values.evaluate(self._quat, initial_value, current_value,
+ random_state)
+ q2 = current_value if self._cumulative else initial_value
+ return transformations.quat_mul(np.asarray(q1), np.asarray(q2))
+
+ def __eq__(self, other):
+ if not isinstance(other, QuaternionPreMultiply):
+ return False
+ return self._quat == other._quat and self._cumulative == other._cumulative
+
+ def __repr__(self):
+ return (
+ f"QuaternionPreMultiply(quat={self._quat},"
+ f" cumulative={self._cumulative})"
+ )
+
+
+class QuaternionRotate(base.Variation):
+ """Variation that rotates a given vector by the given quaternion.
+
+ The vector can either be an existing value passed at evaluation, or specified
+ as a separate variation at construction. In the former case, cumulative mode
+ determines whether to use the current or initial value of the vector. The
+ quaternion is always specified by a variation at construction.
+ """
+
+ def __init__(self, quat, vec=None, cumulative=False):
+ self._quat = quat
+ self._vec = vec
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ quat = variation_values.evaluate(
+ self._quat, initial_value, current_value, random_state
+ )
+ if self._vec is None:
+ vec = current_value if self._cumulative else initial_value
+ else:
+ vec = variation_values.evaluate(
+ self._vec, initial_value, current_value, random_state
+ )
+ return transformations.quat_rotate(np.asarray(quat), np.asarray(vec))
+
+ def __eq__(self, other):
+ if not isinstance(other, QuaternionRotate):
+ return False
+ return (
+ self._quat == other._quat
+ and self._vec == other._vec
+ and self._cumulative == other._cumulative
+ )
+
+ def __repr__(self):
+ return (
+ f"QuaternionRotate(quat={self._quat}, vec={self._vec},"
+ f" cumulative={self._cumulative})"
+ )
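
The sampling formula in `UniformQuaternion.__call__` always yields a unit quaternion, since the squared components sum to `(1 - u1) * (sin(u2)**2 + cos(u2)**2) + u1 * (sin(u3)**2 + cos(u3)**2) = 1`. The following standard-library sketch reproduces the same formula without NumPy and checks the norm numerically; `uniform_quaternion` is a hypothetical helper name.

```python
import math
import random


def uniform_quaternion(rng):
    """Samples a quaternion with the formula used by `UniformQuaternion`."""
    u1 = rng.uniform(0.0, 1.0)
    u2 = rng.uniform(0.0, 2.0 * math.pi)
    u3 = rng.uniform(0.0, 2.0 * math.pi)
    return (
        math.sqrt(1.0 - u1) * math.sin(u2),
        math.sqrt(1.0 - u1) * math.cos(u2),
        math.sqrt(u1) * math.sin(u3),
        math.sqrt(u1) * math.cos(u3),
    )


rng = random.Random(0)
norms = [
    math.sqrt(sum(c * c for c in uniform_quaternion(rng)))
    for _ in range(1000)
]
print(max(abs(n - 1.0) for n in norms))  # deviation from unit norm
```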
diff --git a/dm_control/composer/variation/variation_broadcaster.py b/dm_control/composer/variation/variation_broadcaster.py
new file mode 100644
index 00000000..3f5a1817
--- /dev/null
+++ b/dm_control/composer/variation/variation_broadcaster.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A broadcaster that allows sharing of variation values across many callers."""
+
+import collections
+import weakref
+
+from dm_control.composer import variation
+
+
+class VariationBroadcaster:
+ """Allows a variation to be broadcast to multiple callers.
+
+ This class wraps a `Variation` object and generates multiple proxies that
+ can be used in place of the wrapped `Variation`. The broadcaster updates its
+ value in rounds. At the beginning of each round, the broadcaster re-evaluates
+ the wrapped `Variation` and caches the new value internally. When a proxy
+ is called, the broadcaster will return this cached value, thus ensuring that
+ all proxied values are the same. The round ends when all of the proxies have
+ been called exactly once. It is an error to call any particular proxy more
+ than once per round.
+ """
+
+ def __init__(self, wrapped_variation: variation.Variation):
+ self._wrapped_variation = wrapped_variation
+ self._cached_values = weakref.WeakKeyDictionary()
+
+ def get_proxy(self) -> variation.Variation:
+ """Returns a `Variation` to be used in place of the wrapped `Variation`."""
+ new_proxy = _BroadcastedValueProxy(self)
+ self._cached_values[new_proxy] = collections.deque()
+ return new_proxy
+
+ def _get_value(self, proxy, random_state):
+ """Returns the variation value for a proxy owned by this broadcaster."""
+ cached_values = self._cached_values[proxy]
+ if not cached_values:
+ new_value = variation.evaluate(
+ self._wrapped_variation, None, None, random_state)
+ for values in self._cached_values.values():
+ values.append(new_value)
+ return cached_values.popleft()
+
+
+class _BroadcastedValueProxy(variation.Variation):
+
+ def __init__(self, broadcaster):
+ self._broadcaster = broadcaster
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ value = self._broadcaster._get_value(self, random_state) # pylint: disable=protected-access
+ return value
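
The queueing logic in `_get_value` can be illustrated in isolation. The sketch below is a simplified stand-in under stated assumptions: it uses a plain dict instead of a `WeakKeyDictionary`, a zero-argument `source` callable instead of a `Variation`, and closures instead of proxy objects. `BroadcasterSketch` is a hypothetical name. It shows that, as written, a proxy that runs ahead of the others triggers a new evaluation, and the lagging proxies later receive the queued values in order.

```python
import collections
import itertools


class BroadcasterSketch:
    """Shares each freshly generated value with every registered proxy."""

    def __init__(self, source):
        self._source = source  # Zero-argument callable producing new values.
        self._queues = {}

    def get_proxy(self):
        key = object()
        self._queues[key] = collections.deque()
        return lambda: self._get_value(key)

    def _get_value(self, key):
        queue = self._queues[key]
        if not queue:
            # This proxy has consumed everything so far: generate a new value
            # and enqueue it for all proxies, including this one.
            new_value = self._source()
            for other_queue in self._queues.values():
                other_queue.append(new_value)
        return queue.popleft()


counter = itertools.count()
broadcaster = BroadcasterSketch(lambda: next(counter))
proxy_a = broadcaster.get_proxy()
proxy_b = broadcaster.get_proxy()
values = [proxy_a(), proxy_b(), proxy_a(), proxy_a(), proxy_b(), proxy_b()]
print(values)  # [0, 0, 1, 2, 1, 2]
```

Both proxies observe the same sequence `0, 1, 2`, regardless of the order in which they are called, which is the behaviour exercised by `test_can_generate_values`.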
diff --git a/dm_control/composer/variation/variation_broadcaster_test.py b/dm_control/composer/variation/variation_broadcaster_test.py
new file mode 100644
index 00000000..0f1faa1e
--- /dev/null
+++ b/dm_control/composer/variation/variation_broadcaster_test.py
@@ -0,0 +1,104 @@
+# Copyright 2024 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from absl.testing import absltest
+from dm_control.composer import variation
+from dm_control.composer.variation import distributions
+from dm_control.composer.variation import variation_broadcaster
+import numpy as np
+
+
+class VariationBroadcasterTest(absltest.TestCase):
+
+ def test_can_generate_values(self):
+ random_state = np.random.RandomState(2348)
+ expected_values = [random_state.uniform(0, 1) for _ in range(5)]
+
+ random_state = np.random.RandomState(2348)
+ broadcaster = variation_broadcaster.VariationBroadcaster(
+ distributions.Uniform(0, 1)
+ )
+ proxy_1 = broadcaster.get_proxy()
+ proxy_2 = broadcaster.get_proxy()
+ proxy_3 = broadcaster.get_proxy()
+
+ self.assertEqual(
+ variation.evaluate(proxy_1, random_state=random_state),
+ expected_values[0],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_2, random_state=random_state),
+ expected_values[0],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_3, random_state=random_state),
+ expected_values[0],
+ )
+
+ self.assertEqual(
+ variation.evaluate(proxy_1, random_state=random_state),
+ expected_values[1],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_1, random_state=random_state),
+ expected_values[2],
+ )
+
+ self.assertEqual(
+ variation.evaluate(proxy_2, random_state=random_state),
+ expected_values[1],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_3, random_state=random_state),
+ expected_values[1],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_3, random_state=random_state),
+ expected_values[2],
+ )
+
+ self.assertEqual(
+ variation.evaluate(proxy_3, random_state=random_state),
+ expected_values[3],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_1, random_state=random_state),
+ expected_values[3],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_2, random_state=random_state),
+ expected_values[2],
+ )
+
+ self.assertEqual(
+ variation.evaluate(proxy_1, random_state=random_state),
+ expected_values[4],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_2, random_state=random_state),
+ expected_values[3],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_2, random_state=random_state),
+ expected_values[4],
+ )
+ self.assertEqual(
+ variation.evaluate(proxy_3, random_state=random_state),
+ expected_values[4],
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/variation/variation_test.py b/dm_control/composer/variation/variation_test.py
new file mode 100644
index 00000000..1c611a9f
--- /dev/null
+++ b/dm_control/composer/variation/variation_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for base variation operations."""
+
+import operator
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer import variation
+from dm_control.composer.variation import deterministic
+import numpy as np
+
+
+class VariationTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.value_1 = 3
+ self.variation_1 = deterministic.Constant(self.value_1)
+ self.value_2 = 5
+ self.variation_2 = deterministic.Constant(self.value_2)
+
+ @parameterized.parameters(['neg'])
+ def test_unary_operator(self, name):
+ func = getattr(operator, name)
+ self.assertEqual(
+ variation.evaluate(func(self.variation_1)),
+ func(self.value_1))
+
+ @parameterized.parameters(['add', 'sub', 'mul', 'truediv', 'floordiv', 'pow'])
+ def test_binary_operator(self, name):
+ func = getattr(operator, name)
+ self.assertEqual(
+ variation.evaluate(func(self.value_1, self.variation_2)),
+ func(self.value_1, self.value_2))
+ self.assertEqual(
+ variation.evaluate(func(self.variation_1, self.value_2)),
+ func(self.value_1, self.value_2))
+ self.assertEqual(
+ variation.evaluate(func(self.variation_1, self.variation_2)),
+ func(self.value_1, self.value_2))
+ self.assertEqual(
+ func(self.variation_1, self.variation_2),
+ func(self.variation_1, self.variation_2),
+ )
+ self.assertNotEqual(
+ func(self.variation_1, self.variation_2),
+ func(self.variation_1, self.variation_1),
+ )
+
+ def test_binary_operator_str(self):
+ self.assertEqual(
+ 'add(3, 5)',
+ str(
+ self.variation_1 + self.variation_2
+ ),
+ )
+ self.assertEqual(
+ 'BinaryOperation(add(Constant(3), Constant(5)))',
+ repr(
+ self.variation_1 + self.variation_2
+ ),
+ )
+
+ def test_getitem(self):
+ value = deterministic.Constant(np.array([4, 5, 6, 7, 8]))
+ np.testing.assert_array_equal(
+ variation.evaluate(value[[3, 1]]),
+ [7, 5])
+ self.assertEqual(
+ '[4 5 6 7 8][3]',
+ str(
+ value[3]
+ ),
+ )
+ self.assertEqual(
+ 'GetItemOperation(Constant(array([4, 5, 6, 7, 8]))[3])',
+ repr(
+ value[3]
+ ),
+ )
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/composer/variation/variation_values.py b/dm_control/composer/variation/variation_values.py
new file mode 100644
index 00000000..97adb89d
--- /dev/null
+++ b/dm_control/composer/variation/variation_values.py
@@ -0,0 +1,35 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities for handling nested structures of callables or constants."""
+
+import tree
+
+
+def evaluate(structure, *args, **kwargs):
+ """Evaluates an arbitrarily nested structure of callables or constant values.
+
+ Args:
+ structure: An arbitrarily nested structure of callables or constant values.
+ By "structures", we mean lists, tuples, namedtuples, or dicts.
+ *args: Positional arguments passed to each callable in `structure`.
+ **kwargs: Keyword arguments passed to each callable in `structure`.
+
+ Returns:
+ The same nested structure, with each callable replaced by the value returned
+ by calling it.
+ """
+ return tree.map_structure(
+ lambda x: x(*args, **kwargs) if callable(x) else x, structure)
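
To make the behaviour of `evaluate` concrete without depending on the `tree` package, here is a hedged pure-Python approximation. `evaluate_sketch` is a hypothetical helper that handles only dicts, lists and plain tuples (not namedtuples), so it is a sketch of the idea rather than a drop-in replacement for `tree.map_structure`.

```python
def evaluate_sketch(structure, *args, **kwargs):
    """Calls every callable leaf of `structure`; leaves constants untouched."""
    if isinstance(structure, dict):
        return {
            key: evaluate_sketch(value, *args, **kwargs)
            for key, value in structure.items()
        }
    if isinstance(structure, (list, tuple)):
        return type(structure)(
            evaluate_sketch(item, *args, **kwargs) for item in structure
        )
    return structure(*args, **kwargs) if callable(structure) else structure


structure = {'size': [0.1, lambda scale: 2 * scale], 'name': 'box'}
result = evaluate_sketch(structure, scale=3)
print(result)  # {'size': [0.1, 6], 'name': 'box'}
```

The returned value has the same shape as the input, which is what lets variations such as `MathOp` and `Additive` accept either a constant or a variation object in the same argument position.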
diff --git a/dm_control/entities/__init__.py b/dm_control/entities/__init__.py
new file mode 100644
index 00000000..4224c020
--- /dev/null
+++ b/dm_control/entities/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/entities/manipulators/__init__.py b/dm_control/entities/manipulators/__init__.py
new file mode 100644
index 00000000..dd795512
--- /dev/null
+++ b/dm_control/entities/manipulators/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Composer entities corresponding to robots."""
+
diff --git a/dm_control/entities/manipulators/base.py b/dm_control/entities/manipulators/base.py
new file mode 100644
index 00000000..76d23cbd
--- /dev/null
+++ b/dm_control/entities/manipulators/base.py
@@ -0,0 +1,196 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Abstract base classes for robot arms and hands."""
+
+import abc
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import inverse_kinematics
+import numpy as np
+
+
+DOWN_QUATERNION = np.array([0., 0.70710678118, 0.70710678118, 0.])
+
+_INVALID_JOINTS_ERROR = (
+ 'All non-hinge joints must have limits. Model contains the following '
+ 'non-hinge joints which are unbounded:\n{invalid_str}')
+
+
+class RobotArm(composer.Robot, metaclass=abc.ABCMeta):
+ """The abstract base class for robotic arms."""
+
+ def _build_observables(self):
+ return JointsObservables(self)
+
+ @property
+ def attachment_site(self):
+ return self.wrist_site
+
+ def _get_joint_pos_sampling_bounds(self, physics):
+ """Returns lower and upper bounds for sampling arm joint positions.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+
+ Returns:
+ A (2, num_joints) numpy array containing (lower, upper) position bounds.
+ For hinge joints without limits the bounds are defined as [0, 2pi].
+
+ Raises:
+ RuntimeError: If the model contains unlimited joints that are not hinges.
+ """
+ bound_joints = physics.bind(self.joints)
+ limits = np.array(bound_joints.range, copy=True)
+ is_hinge = bound_joints.type == mjbindings.enums.mjtJoint.mjJNT_HINGE
+ is_limited = bound_joints.limited.astype(bool)
+ invalid = ~is_hinge & ~is_limited # All non-hinge joints must have limits.
+ if any(invalid):
+ invalid_str = '\n'.join(str(self.joints[i]) for i in np.where(invalid)[0])
+ raise RuntimeError(_INVALID_JOINTS_ERROR.format(invalid_str=invalid_str))
+ # For unlimited hinges we sample positions between 0 and 2pi.
+ limits[is_hinge & ~is_limited] = 0., 2*np.pi
+ return limits.T
+
+ def randomize_arm_joints(self, physics, random_state):
+ """Randomizes the qpos of all arm joints.
+
+ The range of qpos values is determined from the MJCF model.
+
+ Args:
+ physics: A `mujoco.Physics` instance.
+ random_state: An `np.random.RandomState` instance.
+ """
+ lower, upper = self._get_joint_pos_sampling_bounds(physics)
+ physics.bind(self.joints).qpos = random_state.uniform(lower, upper)
+
+ def set_site_to_xpos(self, physics, random_state, site, target_pos,
+ target_quat=None, max_ik_attempts=10):
+ """Moves the arm so that a site occurs at the specified location.
+
+ This function runs the inverse kinematics solver to find a configuration of
+ the arm joints for which the given site occurs at the specified location in
+ Cartesian coordinates.
+
+ Args:
+ physics: A `mujoco.Physics` instance.
+ random_state: An `np.random.RandomState` instance.
+ site: Either a `mjcf.Element` or a string specifying the full name
+ of the site whose position is being set.
+ target_pos: The desired Cartesian location of the site.
+ target_quat: (optional) The desired orientation of the site, expressed
+ as a quaternion. If `None`, the default orientation is to point
+ vertically downwards.
+ max_ik_attempts: (optional) Maximum number of attempts to make at finding
+ a solution satisfying `target_pos` and `target_quat`. The joint
+ positions will be randomized after each unsuccessful attempt.
+
+ Returns:
+ A boolean indicating whether the desired configuration is obtained.
+
+ Raises:
+ ValueError: If site is neither a string nor an `mjcf.Element`.
+ """
+ if isinstance(site, mjcf.Element):
+ site_name = site.full_identifier
+ elif isinstance(site, str):
+ site_name = site
+ else:
+ raise ValueError('site should either be a string or mjcf.Element: got {}'
+ .format(site))
+ if target_quat is None:
+ target_quat = DOWN_QUATERNION
+ lower, upper = self._get_joint_pos_sampling_bounds(physics)
+ arm_joint_names = [joint.full_identifier for joint in self.joints]
+
+ for _ in range(max_ik_attempts):
+ result = inverse_kinematics.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ target_pos=target_pos,
+ target_quat=target_quat,
+ joint_names=arm_joint_names,
+ rot_weight=2,
+ inplace=True)
+ success = result.success
+
+ # Wrap each angle into its bounds ([0, 2*pi] for unlimited hinges).
+ if success:
+ for arm_joint, low, high in zip(self.joints, lower, upper):
+ arm_joint_mj = physics.bind(arm_joint)
+ while arm_joint_mj.qpos >= high:
+ arm_joint_mj.qpos -= 2*np.pi
+ while arm_joint_mj.qpos < low:
+ arm_joint_mj.qpos += 2*np.pi
+ if arm_joint_mj.qpos > high:
+ success = False
+ break
+
+ # If succeeded or only one attempt, break and do not randomize joints.
+ if success or max_ik_attempts <= 1:
+ break
+ else:
+ self.randomize_arm_joints(physics, random_state)
+
+ return success
+
+ @property
+ @abc.abstractmethod
+ def joints(self):
+ """Returns the joint elements of the arm."""
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def wrist_site(self):
+ """Returns the wrist site element of the arm."""
+ raise NotImplementedError
+
+
+class JointsObservables(composer.Observables):
+ """Observables common to all robot arms."""
+
+ @define.observable
+ def joints_pos(self):
+ return observable.MJCFFeature('qpos', self._entity.joints)
+
+ @define.observable
+ def joints_vel(self):
+ return observable.MJCFFeature('qvel', self._entity.joints)
+
+
+class RobotHand(composer.Robot, metaclass=abc.ABCMeta):
+ """The abstract base class for robotic hands."""
+
+ @abc.abstractmethod
+ def set_grasp(self, physics, close_factors):
+ """Sets the finger positions to the desired grasp.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ close_factors: A number or list of numbers defining the desired grasp
+ position of each finger. A value of 0 corresponds to fully opening a
+ finger, while a value of 1 corresponds to fully closing it. If a single
+ number is specified, the same position is applied to all fingers.
+ """
+
+ @property
+ @abc.abstractmethod
+ def tool_center_point(self):
+ """Returns the tool center point element of the hand."""
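The bounds logic in `_get_joint_pos_sampling_bounds` and the angle wrapping in `set_site_to_xpos` can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the module's implementation: the function names `sampling_bounds` and `wrap_into_bounds` are hypothetical, and plain lists stand in for the bound MuJoCo arrays.

```python
import math


def sampling_bounds(joint_ranges, is_hinge, is_limited):
  """Returns (lower, upper) lists; unlimited hinges get [0, 2*pi]."""
  lower, upper = [], []
  joints = zip(joint_ranges, is_hinge, is_limited)
  for index, ((low, high), hinge, limited) in enumerate(joints):
    if not limited:
      if not hinge:
        # Mirrors the RuntimeError raised for unbounded non-hinge joints.
        raise RuntimeError(f'Joint {index} is unbounded and not a hinge.')
      low, high = 0.0, 2 * math.pi
    lower.append(low)
    upper.append(high)
  return lower, upper


def wrap_into_bounds(angle, low, high):
  """Shifts `angle` by multiples of 2*pi into [low, high], else None."""
  while angle >= high:
    angle -= 2 * math.pi
  while angle < low:
    angle += 2 * math.pi
  return angle if angle <= high else None


# Hypothetical two-joint arm: one unlimited hinge, one limited hinge.
lower, upper = sampling_bounds(
    [(0.0, 0.0), (-1.0, 1.0)], [True, True], [False, True])
wrapped = wrap_into_bounds(5 * math.pi, lower[0], upper[0])
unreachable = wrap_into_bounds(2.0, lower[1], upper[1])
```

When no multiple of 2*pi brings the angle inside the limits, the sketch returns `None`. This corresponds to the case where `set_site_to_xpos` marks the attempt as unsuccessful and re-randomizes the joints.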
diff --git a/dm_control/entities/manipulators/kinova/__init__.py b/dm_control/entities/manipulators/kinova/__init__.py
new file mode 100644
index 00000000..edff28fd
--- /dev/null
+++ b/dm_control/entities/manipulators/kinova/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Composer models of Kinova robots."""
+
+from dm_control.entities.manipulators.kinova.jaco_arm import JacoArm
+from dm_control.entities.manipulators.kinova.jaco_hand import JacoHand
diff --git a/dm_control/entities/manipulators/kinova/assets_path.py b/dm_control/entities/manipulators/kinova/assets_path.py
new file mode 100644
index 00000000..5111d049
--- /dev/null
+++ b/dm_control/entities/manipulators/kinova/assets_path.py
@@ -0,0 +1,25 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Helper module that specifies the path to Kinova assets."""
+
+import importlib
+import os
+
+
+_DM_CONTROL_ROOT = os.path.dirname(
+ importlib.util.find_spec('dm_control').origin)
+
+KINOVA_ROOT = os.path.join(_DM_CONTROL_ROOT, 'third_party/kinova')
diff --git a/dm_control/entities/manipulators/kinova/jaco_arm.py b/dm_control/entities/manipulators/kinova/jaco_arm.py
new file mode 100644
index 00000000..1b491ee5
--- /dev/null
+++ b/dm_control/entities/manipulators/kinova/jaco_arm.py
@@ -0,0 +1,154 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module containing the Jaco robot class."""
+
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+from dm_control.entities.manipulators import base
+from dm_control.entities.manipulators.kinova import assets_path
+import numpy as np
+
+_JACO_ARM_XML_PATH = os.path.join(assets_path.KINOVA_ROOT, 'jaco_arm.xml')
+_LARGE_JOINTS = ('joint_1', 'joint_2', 'joint_3')
+_SMALL_JOINTS = ('joint_4', 'joint_5', 'joint_6')
+_ALL_JOINTS = _LARGE_JOINTS + _SMALL_JOINTS
+_WRIST_SITE = 'wristsite'
+
+# These are the peak torque limits taken from Kinova's datasheet:
+# https://www.kinovarobotics.com/sites/default/files/AS-ACT-KA58-KA75-SP-INT-EN%20201804-1.2%20%28KINOVA%E2%84%A2%20Actuator%20series%20KA75%2B%20KA-58%20Specifications%29.pdf
+_LARGE_JOINT_MAX_TORQUE = 30.5
+_SMALL_JOINT_MAX_TORQUE = 6.8
+
+# On the real robot these limits are imposed by the actuator firmware. It's
+# technically possible to exceed them via the low-level API, but this can reduce
+# the lifetime of the actuators.
+_LARGE_JOINT_MAX_VELOCITY = np.deg2rad(36.)
+_SMALL_JOINT_MAX_VELOCITY = np.deg2rad(48.)
+
+# The velocity actuator gain is a very rough estimate, and should be considered
+# a placeholder for proper system identification.
+_VELOCITY_GAIN = 500.
+
+
+class JacoArm(base.RobotArm):
+ """A composer entity representing a Jaco arm."""
+
+ def _build(self, name=None):
+ """Initializes the JacoArm.
+
+ Args:
+ name: String, the name of this robot. Used as a prefix in the MJCF
+ name attributes.
+ """
+ self._mjcf_root = mjcf.from_path(_JACO_ARM_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+ # Find MJCF elements that will be exposed as attributes.
+ self._joints = [self._mjcf_root.find('joint', name) for name in _ALL_JOINTS]
+ self._wrist_site = self._mjcf_root.find('site', _WRIST_SITE)
+ self._bodies = self.mjcf_model.find_all('body')
+ # Add actuators.
+ self._actuators = [_add_velocity_actuator(joint) for joint in self._joints]
+ # Add torque sensors.
+ self._joint_torque_sensors = [
+ _add_torque_sensor(joint) for joint in self._joints]
+
+ def _build_observables(self):
+ return JacoArmObservables(self)
+
+ @property
+ def joints(self):
+ """List of joint elements belonging to the arm."""
+ return self._joints
+
+ @property
+ def actuators(self):
+ """List of actuator elements belonging to the arm."""
+ return self._actuators
+
+ @property
+ def joint_torque_sensors(self):
+ """List of torque sensors for each joint belonging to the arm."""
+ return self._joint_torque_sensors
+
+ @property
+ def wrist_site(self):
+ """Wrist site of the arm (attachment point for the hand)."""
+ return self._wrist_site
+
+ @property
+ def mjcf_model(self):
+ """Returns the `mjcf.RootElement` object corresponding to this robot."""
+ return self._mjcf_root
+
+
+def _add_velocity_actuator(joint):
+ """Adds a velocity actuator to a joint, returns the new MJCF element."""
+
+ if joint.name in _LARGE_JOINTS:
+ max_torque = _LARGE_JOINT_MAX_TORQUE
+ max_velocity = _LARGE_JOINT_MAX_VELOCITY
+ elif joint.name in _SMALL_JOINTS:
+ max_torque = _SMALL_JOINT_MAX_TORQUE
+ max_velocity = _SMALL_JOINT_MAX_VELOCITY
+ else:
+ raise ValueError('`joint.name` must be one of {}, got {!r}.'
+ .format(_ALL_JOINTS, joint.name))
+ return joint.root.actuator.add(
+ 'velocity',
+ joint=joint,
+ name=joint.name,
+ kv=_VELOCITY_GAIN,
+ ctrllimited=True,
+ ctrlrange=(-max_velocity, max_velocity),
+ forcelimited=True,
+ forcerange=(-max_torque, max_torque))
+
+
+def _add_torque_sensor(joint):
+ """Adds a torque sensor to a joint, returns the new MJCF element."""
+ site = joint.parent.add(
+ 'site', size=[1e-3], group=composer.SENSOR_SITES_GROUP,
+ name=joint.name+'_site')
+ return joint.root.sensor.add('torque', site=site, name=joint.name+'_torque')
+
+
+class JacoArmObservables(base.JointsObservables):
+ """Jaco arm observables."""
+
+ @define.observable
+ def joints_pos(self):
+ # Because most of the Jaco arm joints are unlimited, we return the joint
+ # angles as sine/cosine pairs so that the observations are bounded.
+ def get_sin_cos_joint_angles(physics):
+ joint_pos = physics.bind(self._entity.joints).qpos
+ return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
+ return observable.Generic(get_sin_cos_joint_angles)
+
+ @define.observable
+ def joints_torque(self):
+ # MuJoCo's torque sensors are 3-axis, but we are only interested in torques
+ # acting about the axis of rotation of the joint. We therefore project the
+ # torques onto the joint axis.
+ def get_torques(physics):
+ torques = physics.bind(self._entity.joint_torque_sensors).sensordata
+ joint_axes = physics.bind(self._entity.joints).axis
+ return np.einsum('ij,ij->i', torques.reshape(-1, 3), joint_axes)
+ return observable.Generic(get_torques)
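The two observables in `JacoArmObservables` reduce to simple arithmetic, sketched here without NumPy or MuJoCo. The function names and sample values are illustrative assumptions. The real code reads `qpos`, `sensordata` and `axis` from a bound `physics` object, and reshapes the flat sensor data to rows of three before projecting.

```python
import math


def sin_cos_pairs(joint_angles):
  """Encodes each angle as a bounded [sin, cos] pair."""
  return [[math.sin(angle), math.cos(angle)] for angle in joint_angles]


def project_torques(torques, joint_axes):
  """Row-wise dot product, a pure-Python analogue of einsum('ij,ij->i')."""
  return [sum(t * a for t, a in zip(torque, axis))
          for torque, axis in zip(torques, joint_axes)]


# Angles that differ by a multiple of 2*pi give the same observation.
observation = sin_cos_pairs([0.0, 10 * math.pi])

# Hypothetical 3-axis sensor readings and unit joint axes.
axial = project_torques(
    [[0.5, 0.0, 2.0], [1.0, 3.0, 0.0]],
    [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
```

The sine/cosine encoding keeps observations bounded even when an unlimited joint has wound through several full turns. The projection discards torque components that do not act about the joint axis.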
diff --git a/dm_control/entities/manipulators/kinova/jaco_hand.py b/dm_control/entities/manipulators/kinova/jaco_hand.py
new file mode 100644
index 00000000..b2977bac
--- /dev/null
+++ b/dm_control/entities/manipulators/kinova/jaco_hand.py
@@ -0,0 +1,170 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module containing the standard Jaco hand."""
+
+import collections
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.entities.manipulators import base
+from dm_control.entities.manipulators.kinova import assets_path
+
+_JACO_HAND_XML_PATH = os.path.join(assets_path.KINOVA_ROOT, 'jaco_hand.xml')
+_HAND_BODY = 'hand'
+_PINCH_SITE = 'pinchsite'
+_GRIP_SITE = 'gripsite'
+
+
+class JacoHand(base.RobotHand):
+ """A composer entity representing a Jaco hand."""
+
+ def _build(self,
+ name=None,
+ use_pinch_site_as_tcp=False):
+ """Initializes the JacoHand.
+
+ Args:
+ name: String, the name of this robot. Used as a prefix in the MJCF
+ name attributes.
+ use_pinch_site_as_tcp: (optional) A boolean, if `True` the pinch site
+ will be used as the tool center point. If `False` the grip site is used.
+ """
+ self._mjcf_root = mjcf.from_path(_JACO_HAND_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+ # Find MJCF elements that will be exposed as attributes.
+ self._bodies = self.mjcf_model.find_all('body')
+ self._tool_center_point = self._mjcf_root.find(
+ 'site', _PINCH_SITE if use_pinch_site_as_tcp else _GRIP_SITE)
+ self._joints = self._mjcf_root.find_all('joint')
+ self._hand_geoms = list(self._mjcf_root.find('body', _HAND_BODY).geom)
+ self._finger_geoms = [geom for geom in self._mjcf_root.find_all('geom')
+ if geom.name and geom.name.startswith('finger')]
+ self._grip_site = self._mjcf_root.find('site', _GRIP_SITE)
+ self._pinch_site = self._mjcf_root.find('site', _PINCH_SITE)
+
+ # Add actuators.
+ self._finger_actuators = [
+ _add_velocity_actuator(joint) for joint in self._joints]
+
+ def _build_observables(self):
+ return JacoHandObservables(self)
+
+ @property
+ def tool_center_point(self):
+ """Tool center point for the Jaco hand."""
+ return self._tool_center_point
+
+ @property
+ def joints(self):
+ """List of joint elements."""
+ return self._joints
+
+ @property
+ def actuators(self):
+ """List of finger actuators."""
+ return self._finger_actuators
+
+ @property
+ def hand_geom(self):
+ """List of geoms belonging to the hand."""
+ return self._hand_geoms
+
+ @property
+ def finger_geoms(self):
+ """List of geoms belonging to the fingers."""
+ return self._finger_geoms
+
+ @property
+ def grip_site(self):
+ """Grip site."""
+ return self._grip_site
+
+ @property
+ def pinch_site(self):
+ """Pinch site."""
+ return self._pinch_site
+
+ @property
+ def pinch_site_pos_sensor(self):
+ """Sensor that returns the Cartesian position of the pinch site."""
+ return self._pinch_site_pos_sensor
+
+ @property
+ def pinch_site_quat_sensor(self):
+ """Sensor that returns the orientation of the pinch site as a quaternion."""
+ return self._pinch_site_quat_sensor
+
+ @property
+ def mjcf_model(self):
+ """Returns the `mjcf.RootElement` object corresponding to this robot."""
+ return self._mjcf_root
+
+ def set_grasp(self, physics, close_factors):
+ """Sets the finger positions to the desired grasp.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ close_factors: A number or list of numbers defining the desired grasp
+ position of each finger. A value of 0 corresponds to fully opening a
+ finger, while a value of 1 corresponds to fully closing it. If a single
+ number is specified, the same position is applied to all fingers.
+ """
+ if not isinstance(close_factors, collections.abc.Iterable):
+ close_factors = (close_factors,) * len(self.joints)
+ for joint, finger_factor in zip(self.joints, close_factors):
+ joint_mj = physics.bind(joint)
+ min_value, max_value = joint_mj.range
+ joint_mj.qpos = min_value + (max_value - min_value) * finger_factor
+ physics.after_reset()
+
+ # Set target joint velocities to zero.
+ physics.bind(self.actuators).ctrl = 0
+
+
+def _add_velocity_actuator(joint):
+ """Adds a velocity actuator to a joint, returns the new MJCF element."""
+ # These parameters were adjusted to achieve a grip force of ~25 N and a finger
+ # closing time of ~1.2 s, as specified in the datasheet for the hand.
+ gain = 10.
+ forcerange = (-1., 1.)
+ ctrlrange = (-5., 5.) # Based on Kinova's URDF.
+ return joint.root.actuator.add(
+ 'velocity',
+ joint=joint,
+ name=joint.name,
+ kv=gain,
+ ctrllimited=True,
+ ctrlrange=ctrlrange,
+ forcelimited=True,
+ forcerange=forcerange)
+
+
+class JacoHandObservables(base.JointsObservables):
+ """Observables for the Jaco hand."""
+
+ @composer.observable
+ def pinch_site_pos(self):
+ """The position of the pinch site, in global coordinates."""
+ return observable.MJCFFeature('xpos', self._entity.pinch_site)
+
+ @composer.observable
+ def pinch_site_rmat(self):
+ """The rotation matrix of the pinch site in global coordinates."""
+ return observable.MJCFFeature('xmat', self._entity.pinch_site)
+
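The interpolation performed by `JacoHand.set_grasp` can be illustrated in isolation. This is a sketch only: `grasp_positions` is a hypothetical helper, the joint ranges are made-up values, and the real method also writes the result to `qpos`, calls `physics.after_reset()` and zeroes the actuator controls.

```python
from collections.abc import Iterable


def grasp_positions(joint_ranges, close_factors):
  """Maps close factors in [0, 1] linearly onto each joint's range."""
  if not isinstance(close_factors, Iterable):
    # A single number applies the same factor to every finger.
    close_factors = (close_factors,) * len(joint_ranges)
  return [low + (high - low) * factor
          for (low, high), factor in zip(joint_ranges, close_factors)]


# Hypothetical (min, max) joint ranges for three fingers, in radians.
ranges = [(0.0, 1.5), (0.0, 1.5), (0.0, 2.0)]
half_closed = grasp_positions(ranges, 0.5)
mixed = grasp_positions(ranges, [0.0, 1.0, 0.25])
```

A factor of 0 yields the lower limit (fully open) and a factor of 1 the upper limit (fully closed), which matches the docstring's description of `close_factors`.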
diff --git a/dm_control/entities/manipulators/kinova/kinova_test.py b/dm_control/entities/manipulators/kinova/kinova_test.py
new file mode 100644
index 00000000..88e5fbb8
--- /dev/null
+++ b/dm_control/entities/manipulators/kinova/kinova_test.py
@@ -0,0 +1,281 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the Jaco arm and hand classes."""
+
+import itertools
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.entities.manipulators import kinova
+from dm_control.entities.manipulators.kinova import jaco_arm
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+
+class JacoArmTest(parameterized.TestCase):
+
+ def test_can_compile_and_step_model(self):
+ arm = kinova.JacoArm()
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ physics.step()
+
+ def test_can_attach_hand(self):
+ arm = kinova.JacoArm()
+ hand = kinova.JacoHand()
+ arm.attach(hand)
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ physics.step()
+
+ # TODO(b/159974149): Investigate why the mass does not match the datasheet.
+ @unittest.expectedFailure
+ def test_mass(self):
+ arm = kinova.JacoArm()
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ mass = physics.bind(arm.mjcf_model.worldbody).subtreemass
+ expected_mass = 4.4
+ self.assertAlmostEqual(mass, expected_mass)
+
+ @parameterized.parameters([
+ dict(actuator_index=0,
+ control_input=0,
+ expected_velocity=0.),
+ dict(actuator_index=0,
+ control_input=jaco_arm._LARGE_JOINT_MAX_VELOCITY,
+ expected_velocity=jaco_arm._LARGE_JOINT_MAX_VELOCITY),
+ dict(actuator_index=4,
+ control_input=jaco_arm._SMALL_JOINT_MAX_VELOCITY,
+ expected_velocity=jaco_arm._SMALL_JOINT_MAX_VELOCITY),
+ dict(actuator_index=0,
+ control_input=-jaco_arm._LARGE_JOINT_MAX_VELOCITY,
+ expected_velocity=-jaco_arm._LARGE_JOINT_MAX_VELOCITY),
+ dict(actuator_index=0,
+ control_input=2*jaco_arm._LARGE_JOINT_MAX_VELOCITY, # Test clipping
+ expected_velocity=jaco_arm._LARGE_JOINT_MAX_VELOCITY),
+ ])
+ def test_velocity_actuation(
+ self, actuator_index, control_input, expected_velocity):
+ arm = kinova.JacoArm()
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ actuator = arm.actuators[actuator_index]
+ bound_actuator = physics.bind(actuator)
+ bound_joint = physics.bind(actuator.joint)
+ acceleration_threshold = 1e-6
+ with physics.model.disable('contact', 'gravity'):
+ bound_actuator.ctrl = control_input
+ # Step until the joint has stopped accelerating.
+ while abs(bound_joint.qacc) > acceleration_threshold:
+ physics.step()
+ self.assertAlmostEqual(bound_joint.qvel[0], expected_velocity, delta=0.01)
+
+ @parameterized.parameters([
+ dict(joint_index=0, min_expected_torque=1.7, max_expected_torque=5.2),
+ dict(joint_index=5, min_expected_torque=0.8, max_expected_torque=7.0)])
+ def test_backdriving_torque(
+ self, joint_index, min_expected_torque, max_expected_torque):
+ arm = kinova.JacoArm()
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ bound_joint = physics.bind(arm.joints[joint_index])
+ torque = min_expected_torque * 0.8
+ velocity_threshold = 0.1*2*np.pi/60. # 0.1 RPM
+ torque_increment = 0.01
+ seconds_per_torque_increment = 1.
+ max_torque = max_expected_torque * 1.1
+ while torque < max_torque:
+ # Ensure that no other forces are acting on the arm.
+ with physics.model.disable('gravity', 'contact', 'actuation'):
+ # Reset the simulation so that the initial velocity is zero.
+ physics.reset()
+ bound_joint.qfrc_applied = torque
+ while physics.time() < seconds_per_torque_increment:
+ physics.step()
+ if bound_joint.qvel[0] >= velocity_threshold:
+ self.assertBetween(torque, min_expected_torque, max_expected_torque)
+ return
+ # If we failed to accelerate the joint to the target velocity within the
+ # time limit we'll reset the simulation and increase the torque.
+ torque += torque_increment
+ self.fail('Torque of {} Nm insufficient to backdrive joint.'.format(torque))
+
+ @parameterized.parameters([
+ dict(joint_pos=0., expected_obs=[0., 1.]),
+ dict(joint_pos=-0.5*np.pi, expected_obs=[-1., 0.]),
+ dict(joint_pos=np.pi, expected_obs=[0., -1.]),
+ dict(joint_pos=10*np.pi, expected_obs=[0., 1.])])
+ def test_joints_pos_observables(self, joint_pos, expected_obs):
+ joint_index = 0
+ arm = kinova.JacoArm()
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ physics.bind(arm.joints).qpos[joint_index] = joint_pos
+ actual_obs = arm.observables.joints_pos(physics)[joint_index]
+ np.testing.assert_array_almost_equal(expected_obs, actual_obs)
+
+ @parameterized.parameters(
+ dict(joint_index=idx, applied_torque=t)
+ for idx, t in itertools.product([0, 2, 4], [0., -6.8, 30.5]))
+ def test_joints_torque_observables(self, joint_index, applied_torque):
+ arm = kinova.JacoArm()
+ joint = arm.joints[joint_index]
+ physics = mjcf.Physics.from_mjcf_model(arm.mjcf_model)
+ with physics.model.disable('gravity', 'limit', 'contact', 'actuation'):
+ # Apply a Cartesian torque to the body containing the joint. We use
+ # `xfrc_applied` rather than `qfrc_applied` because forces in
+ # `qfrc_applied` are not measured by the torque sensor.
+ physics.bind(joint.parent).xfrc_applied[3:] = (
+ applied_torque * physics.bind(joint).xaxis)
+ observed_torque = arm.observables.joints_torque(physics)[joint_index]
+ # Note the change in sign, since the sensor measures torques in the
+ # child->parent direction.
+ self.assertAlmostEqual(observed_torque, -applied_torque, delta=0.1)
+
+
+class JacoHandTest(parameterized.TestCase):
+
+ def test_can_compile_and_step_model(self):
+ hand = kinova.JacoHand()
+ physics = mjcf.Physics.from_mjcf_model(hand.mjcf_model)
+ physics.step()
+
+ # TODO(b/159974149): Investigate why the mass does not match the datasheet.
+ @unittest.expectedFailure
+ def test_hand_mass(self):
+ hand = kinova.JacoHand()
+ physics = mjcf.Physics.from_mjcf_model(hand.mjcf_model)
+ mass = physics.bind(hand.mjcf_model.worldbody).subtreemass
+ expected_mass = 0.727
+ self.assertAlmostEqual(mass, expected_mass)
+
+ def test_grip_force(self):
+ arena = composer.Arena()
+ hand = kinova.JacoHand()
+ arena.attach(hand)
+
+ # A sphere with a touch sensor for measuring grip force.
+ prop_model = mjcf.RootElement(model='grip_target')
+ prop_model.worldbody.add('geom', type='sphere', size=[0.02])
+ touch_site = prop_model.worldbody.add('site', type='sphere', size=[0.025])
+ touch_sensor = prop_model.sensor.add('touch', site=touch_site)
+ prop = composer.ModelWrapperEntity(prop_model)
+
+ # Add some slide joints to allow movement of the target in the XY plane.
+ # This helps the contact solver to converge more reliably.
+ prop_frame = arena.attach(prop)
+ prop_frame.add('joint', name='slide_x', type='slide', axis=(1, 0, 0))
+ prop_frame.add('joint', name='slide_y', type='slide', axis=(0, 1, 0))
+
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ bound_pinch_site = physics.bind(hand.pinch_site)
+ bound_actuators = physics.bind(hand.actuators)
+ bound_joints = physics.bind(hand.joints)
+ bound_touch = physics.bind(touch_sensor)
+
+ # Position the grip target at the pinch site.
+ prop.set_pose(physics, position=bound_pinch_site.xpos)
+
+ # Close the fingers with as much force as the actuators will allow.
+ bound_actuators.ctrl = bound_actuators.ctrlrange[:, 1]
+
+ # Run the simulation forward until the joints stop moving.
+ physics.step()
+ qvel_thresh = 1e-3 # radians / s
+ while max(abs(bound_joints.qvel)) > qvel_thresh:
+ physics.step()
+ expected_min_grip_force = 20.
+ expected_max_grip_force = 30.
+ grip_force = bound_touch.sensordata
+ self.assertBetween(
+ grip_force, expected_min_grip_force, expected_max_grip_force,
+ msg='Expected grip force to be between {} and {} N, got {} N.'.format(
+ expected_min_grip_force, expected_max_grip_force, grip_force))
+
+ @parameterized.parameters([dict(opening=True), dict(opening=False)])
+ def test_finger_travel_time(self, opening):
+ hand = kinova.JacoHand()
+ physics = mjcf.Physics.from_mjcf_model(hand.mjcf_model)
+ bound_actuators = physics.bind(hand.actuators)
+ bound_joints = physics.bind(hand.joints)
+ min_ctrl, max_ctrl = bound_actuators.ctrlrange.T
+ min_qpos, max_qpos = bound_joints.range.T
+
+ # Measure the time taken for the finger joints to traverse 99.9% of their
+ # total range.
+ qpos_tol = 1e-3 * (max_qpos - min_qpos)
+ if opening:
+ hand.set_grasp(physics=physics, close_factors=1.) # Fully closed.
+ np.testing.assert_array_almost_equal(bound_joints.qpos, max_qpos)
+ target_pos = min_qpos # Fully open.
+ ctrl = min_ctrl # Open the fingers as fast as the actuators will allow.
+ else:
+ hand.set_grasp(physics=physics, close_factors=0.) # Fully open.
+ np.testing.assert_array_almost_equal(bound_joints.qpos, min_qpos)
+ target_pos = max_qpos # Fully closed.
+ ctrl = max_ctrl # Close the fingers as fast as the actuators will allow.
+
+ # Run the simulation until all joints have reached their target positions.
+ bound_actuators.ctrl = ctrl
+ while np.any(abs(bound_joints.qpos - target_pos) > qpos_tol):
+ with physics.model.disable('gravity'):
+ physics.step()
+ expected_travel_time = 1.2 # Seconds.
+ self.assertAlmostEqual(physics.time(), expected_travel_time, delta=0.1)
+
+ @parameterized.parameters([
+ dict(pos=np.r_[0., 0., 0.3], quat=np.r_[0., 1., 0., 1.]),
+ dict(pos=np.r_[0., -0.1, 0.5], quat=np.r_[1., 1., 0., 0.]),
+ ])
+ def test_pinch_site_observables(self, pos, quat):
+ arm = kinova.JacoArm()
+ hand = kinova.JacoHand()
+ arena = composer.Arena()
+ arm.attach(hand)
+ arena.attach(arm)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ # Normalize the quaternion.
+ quat /= np.linalg.norm(quat)
+
+ # Drive the arm so that the pinch site is at the desired position and
+ # orientation.
+ success = arm.set_site_to_xpos(
+ physics=physics,
+ random_state=np.random.RandomState(0),
+ site=hand.pinch_site,
+ target_pos=pos,
+ target_quat=quat)
+ self.assertTrue(success)
+
+ # Check that the observations are as expected.
+ observed_pos = hand.observables.pinch_site_pos(physics)
+ np.testing.assert_allclose(observed_pos, pos, atol=1e-3)
+
+ observed_rmat = hand.observables.pinch_site_rmat(physics).reshape(3, 3)
+ expected_rmat = np.empty((3, 3), np.double)
+ mjlib.mju_quat2Mat(expected_rmat.ravel(), quat)
+ difference_rmat = observed_rmat.dot(expected_rmat.T)
+ # `difference_rmat` might not be perfectly orthonormal, which could lead to
+ # an invalid value being passed to arccos.
+ u, _, vt = np.linalg.svd(difference_rmat, full_matrices=False)
+ ortho_difference_rmat = u.dot(vt)
+ angular_difference = np.arccos((np.trace(ortho_difference_rmat) - 1) / 2)
+ self.assertLess(angular_difference, 1e-3)
+
+
+if __name__ == '__main__':
+ absltest.main()
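Note on the orientation check in `test_pinch_site_observables`: the test measures the angle of the relative rotation between the observed and expected matrices from its trace. The following is a minimal, dependency-free sketch of that idea. It swaps the SVD re-orthonormalization used in the test for a simple clamp of the cosine, which is a different (cruder) guard against invalid `arccos` inputs. The helper names `rotation_angle_between` and `rot_z` are hypothetical and are not part of `dm_control`.

```python
import math


def rotation_angle_between(rmat_a, rmat_b):
  """Angle (rad) of the relative rotation rmat_a * rmat_b^T, for 3x3 lists."""
  # trace(A @ B.T) equals the sum of elementwise products of A and B.
  trace = sum(rmat_a[i][j] * rmat_b[i][j] for i in range(3) for j in range(3))
  cos_angle = (trace - 1.0) / 2.0
  # Clamp to guard against rounding pushing the value just outside [-1, 1].
  return math.acos(max(-1.0, min(1.0, cos_angle)))


def rot_z(theta):
  """Rotation matrix for a rotation of `theta` radians about the z-axis."""
  c, s = math.cos(theta), math.sin(theta)
  return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


# Two rotations about the same axis differ by the difference of their angles.
angle = rotation_angle_between(rot_z(0.7), rot_z(0.2))
```

Clamping only protects `acos` from out-of-range inputs; it does not correct a matrix that has drifted away from orthonormality, which is what the SVD projection in the test addresses.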
diff --git a/dm_control/entities/props/__init__.py b/dm_control/entities/props/__init__.py
new file mode 100644
index 00000000..55e00f05
--- /dev/null
+++ b/dm_control/entities/props/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Composer entities corresponding to props.
+
+A "prop" is typically a non-actuated entity representing an object in the world.
+"""
+
+from dm_control.entities.props.duplo import Duplo
+from dm_control.entities.props.position_detector import PositionDetector
+from dm_control.entities.props.primitive import Primitive
diff --git a/dm_control/entities/props/duplo/__init__.py b/dm_control/entities/props/duplo/__init__.py
new file mode 100644
index 00000000..9605e49c
--- /dev/null
+++ b/dm_control/entities/props/duplo/__init__.py
@@ -0,0 +1,169 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A 2x4 Duplo brick."""
+
+import collections
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+import numpy as np
+
+_DUPLO_XML_PATH = os.path.join(os.path.dirname(__file__), 'duplo2x4.xml')
+
+# Stud radii are drawn from a uniform distribution. The `variation` argument
+# scales the minimum and maximum whilst keeping the lower quartile constant.
+_StudSize = collections.namedtuple(
+ '_StudSize', ['minimum', 'lower_quartile', 'maximum'])
+_StudParams = collections.namedtuple('_StudParams', ['easy_align', 'flanges'])
+
+_STUD_SIZE_PARAMS = {
+ _StudParams(easy_align=False, flanges=False):
+ _StudSize(minimum=0.004685, lower_quartile=0.004781, maximum=0.004898),
+ _StudParams(easy_align=False, flanges=True):
+ _StudSize(minimum=0.004609, lower_quartile=0.004647, maximum=0.004716),
+ _StudParams(easy_align=True, flanges=False):
+ _StudSize(minimum=0.004754, lower_quartile=0.004844, maximum=0.004953),
+ _StudParams(easy_align=True, flanges=True):
+ _StudSize(minimum=0.004695, lower_quartile=0.004717, maximum=0.004765)
+}
+
+_COLOR_NOT_BETWEEN_0_AND_1 = (
+ 'All values in `color` must be between 0 and 1, got {!r}.')
+
+
+class Duplo(composer.Entity):
+ """A 2x4 Duplo brick."""
+
+ def _build(self, easy_align=False, flanges=True, variation=0.0,
+ color=(1., 0., 0.)):
+ """Initializes a new `Duplo` instance.
+
+ Args:
+ easy_align: If True, the studs on the top of the brick will be capsules
+ rather than cylinders. This makes alignment easier.
+ flanges: Whether to use flanges on the bottom of the brick. These make the
+ dynamics more expensive, but allow the bricks to be clicked together in
+ partially overlapping configurations.
+ variation: A float that controls the amount of variation in stud size (and
+ therefore separation force). A value of 1.0 results in a distribution of
+ separation forces that approximately matches the empirical distribution
+ measured for real Duplo bricks. A value of 0.0 yields a deterministic
+ separation force approximately equal to the mode of the empirical
+ distribution.
+ color: An optional tuple of (R, G, B) values specifying the color of the
+ Duplo brick. These should be floats between 0 and 1. The default is red.
+
+ Raises:
+ ValueError: If `color` contains any value that is not between 0 and 1.
+ """
+ self._mjcf_root = mjcf.from_path(_DUPLO_XML_PATH)
+
+ stud = self._mjcf_root.default.find('default', 'stud')
+ if easy_align:
+ # Make cylindrical studs invisible and disable contacts.
+ stud.geom.group = 3
+ stud.geom.contype = 9
+ stud.geom.conaffinity = 8
+ # Make capsule studs visible and enable contacts.
+ stud_cap = self._mjcf_root.default.find('default', 'stud-capsule')
+ stud_cap.geom.group = 0
+ stud_cap.geom.contype = 0
+ stud_cap.geom.conaffinity = 4
+ self._active_stud_dclass = stud_cap
+ else:
+ self._active_stud_dclass = stud
+
+ if flanges:
+ flange_dclass = self._mjcf_root.default.find('default', 'flange')
+ flange_dclass.geom.contype = 4 # Enable contact with flanges.
+
+ stud_size = _STUD_SIZE_PARAMS[(easy_align, flanges)]
+ offset = (1 - variation) * stud_size.lower_quartile
+ self._lower = offset + variation * stud_size.minimum
+ self._upper = offset + variation * stud_size.maximum
+
+ self._studs = np.ndarray((2, 4), dtype=object)
+ self._holes = np.ndarray((2, 4), dtype=object)
+
+ for row in range(2):
+ for column in range(4):
+ self._studs[row, column] = self._mjcf_root.find(
+ 'site', 'stud_{}{}'.format(row, column))
+ self._holes[row, column] = self._mjcf_root.find(
+ 'site', 'hole_{}{}'.format(row, column))
+
+ if not all(0 <= value <= 1 for value in color):
+ raise ValueError(_COLOR_NOT_BETWEEN_0_AND_1.format(color))
+ self._mjcf_root.default.geom.rgba[:3] = color
+
+ def initialize_episode_mjcf(self, random_state):
+ """Randomizes the stud radius (and therefore the separation force)."""
+ radius = random_state.uniform(self._lower, self._upper)
+ self._active_stud_dclass.geom.size[0] = radius
+
+ def _build_observables(self):
+ return DuploObservables(self)
+
+ @property
+ def studs(self):
+ """A (2, 4) numpy array of `mjcf.Elements` corresponding to stud sites."""
+ return self._studs
+
+ @property
+ def holes(self):
+ """A (2, 4) numpy array of `mjcf.Elements` corresponding to hole sites."""
+ return self._holes
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class DuploObservables(composer.Observables, composer.FreePropObservableMixin):
+ """Observables for the `Duplo` prop."""
+
+ @define.observable
+ def position(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.find('sensor', 'position'))
+
+ @define.observable
+ def orientation(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.find('sensor', 'orientation'))
+
+ @define.observable
+ def linear_velocity(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.find('sensor', 'linear_velocity'))
+
+ @define.observable
+ def angular_velocity(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.find('sensor', 'angular_velocity'))
+
+ @define.observable
+ def force(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.find('sensor', 'force'))
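Note on the `variation` argument of `Duplo._build`: the sampling bounds for the stud radius are a linear interpolation between the lower quartile and the measured minimum and maximum. The sketch below reproduces only that arithmetic as a standalone example, using the `easy_align=False, flanges=True` values from `_STUD_SIZE_PARAMS`. The helper name `stud_radius_bounds` is hypothetical and does not exist in `dm_control`.

```python
import collections

StudSize = collections.namedtuple(
    'StudSize', ['minimum', 'lower_quartile', 'maximum'])

# Values copied from `_STUD_SIZE_PARAMS` for easy_align=False, flanges=True.
STUD_SIZE = StudSize(minimum=0.004609, lower_quartile=0.004647,
                     maximum=0.004716)


def stud_radius_bounds(stud_size, variation):
  """Hypothetical helper mirroring the bounds computed in `Duplo._build`."""
  offset = (1 - variation) * stud_size.lower_quartile
  lower = offset + variation * stud_size.minimum
  upper = offset + variation * stud_size.maximum
  return lower, upper


# variation=0.0 collapses the sampling range onto the lower quartile.
fixed_lower, fixed_upper = stud_radius_bounds(STUD_SIZE, 0.0)
# variation=1.0 spans the full minimum..maximum range.
full_lower, full_upper = stud_radius_bounds(STUD_SIZE, 1.0)
```

With `variation=0.0` every call to `initialize_episode_mjcf` therefore draws the same radius, which is why the fixed-force test expects identical separation forces across seeds.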
diff --git a/dm_control/entities/props/duplo/autotune.py b/dm_control/entities/props/duplo/autotune.py
new file mode 100644
index 00000000..28eed392
--- /dev/null
+++ b/dm_control/entities/props/duplo/autotune.py
@@ -0,0 +1,160 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Script for tuning Duplo stud sizes to give desired separation forces."""
+
+import collections
+import pprint
+from absl import app
+from absl import logging
+from dm_control.entities.props import duplo
+from dm_control.entities.props.duplo import utils
+from scipy import optimize
+
+# pylint: disable=protected-access,invalid-name
+_StudSize = duplo._StudSize
+ORIGINAL_STUD_SIZE_PARAMS = duplo._STUD_SIZE_PARAMS
+# pylint: enable=protected-access,invalid-name
+
+DESIRED_FORCES = _StudSize(minimum=6., lower_quartile=10., maximum=18.)
+
+# The safety margin here is because the separation force isn't quite monotonic
+# w.r.t. the stud radius. If we set the min and max radii according to the
+# exact desired bounds on the separation force then we may occasionally sample
+# stud radii that yield out-of-bounds forces.
+SAFETY_MARGIN = 0.2
+
+
+def get_separation_force_for_radius(radius, **duplo_kwargs):
+ """Measures Duplo separation force as a function of stud radius."""
+
+ top_brick = duplo.Duplo(**duplo_kwargs)
+ bottom_brick = duplo.Duplo(**duplo_kwargs)
+
+ # Set the radius of the studs on the bottom brick (this would normally be done
+ # in `initialize_episode_mjcf`). Note: we also set the radius of the studs on
+ # the top brick, since this has a (tiny!) effect on its mass.
+
+ # pylint: disable=protected-access
+ top_brick._active_stud_dclass.geom.size[0] = radius
+ bottom_brick._active_stud_dclass.geom.size[0] = radius
+ # pylint: enable=protected-access
+
+ separation_force = utils.measure_separation_force(top_brick, bottom_brick)
+ logging.debug('Stud radius: %f\tseparation force: %f N',
+ radius, separation_force)
+ return separation_force
+
+
+class _KeepBracketingSolutions:
+ """Wraps objective func, keeps closest solutions bracketing the target."""
+
+ _solution = collections.namedtuple('_solution', ['x', 'residual'])
+
+ def __init__(self, func):
+ self._func = func
+ self.below = self._solution(x=None, residual=-float('inf'))
+ self.above = self._solution(x=None, residual=float('inf'))
+
+ def __call__(self, x):
+ residual = self._func(x)
+ if self.below.residual < residual <= 0:
+ self.below = self._solution(x=x, residual=residual)
+ elif 0 < residual < self.above.residual:
+ self.above = self._solution(x=x, residual=residual)
+ return residual
+
+ @property
+ def closest(self):
+ if abs(self.below.residual) < self.above.residual:
+ return self.below
+ else:
+ return self.above
+
+
+def tune_stud_radius(desired_force,
+ min_radius=0.0045,
+ max_radius=0.005,
+ desired_places=6,
+ side='closest',
+ **duplo_kwargs):
+ """Find a stud size that gives the desired separation force."""
+
+ @_KeepBracketingSolutions
+ def func(radius):
+ radius = round(radius, desired_places) # Round radius for aesthetics (!)
+ return (get_separation_force_for_radius(radius=radius, **duplo_kwargs)
+ - desired_force)
+
+ # Ensure that the min and max radii bracket the solution.
+ while func(min_radius) > 0:
+ min_radius = max(1e-3, min_radius - (max_radius - min_radius))
+ while func(max_radius) < 0:
+ max_radius += (max_radius - min_radius)
+
+ tolerance = 10**-(desired_places)
+
+ # Use bisection to refine the bounds on the optimal radius. Note: this assumes
+ # that separation force is monotonic w.r.t. stud radius, but this isn't
+ # exactly true in all cases.
+ optimize.bisect(func, a=min_radius, b=max_radius, xtol=tolerance, disp=True)
+
+ if side == 'below':
+ solution = func.below
+ elif side == 'above':
+ solution = func.above
+ else:
+ solution = func.closest
+
+ radius = round(solution.x, desired_places)
+ force = get_separation_force_for_radius(radius, **duplo_kwargs)
+
+ return radius, force
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ tuned_stud_radii = {}
+ tuned_separation_forces = {}
+
+ for stud_params in sorted(ORIGINAL_STUD_SIZE_PARAMS):
+ duplo_kwargs = stud_params._asdict()
+
+ min_result = tune_stud_radius(
+ desired_force=DESIRED_FORCES.minimum + SAFETY_MARGIN,
+ variation=0.0, side='above', **duplo_kwargs)
+ lq_result = tune_stud_radius(
+ desired_force=DESIRED_FORCES.lower_quartile,
+ variation=0.0, side='closest', **duplo_kwargs)
+ max_result = tune_stud_radius(
+ desired_force=DESIRED_FORCES.maximum - SAFETY_MARGIN,
+ variation=0.0, side='below', **duplo_kwargs)
+
+ radii, forces = zip(*(min_result, lq_result, max_result))
+
+ logging.info('\nDuplo configuration: %s\nTuned radii: %s, forces: %s',
+ stud_params, radii, forces)
+ tuned_stud_radii[stud_params] = _StudSize(*radii)
+ tuned_separation_forces[stud_params] = _StudSize(*forces)
+
+ logging.info('%s\nNew Duplo parameters:\n%s\nSeparation forces:\n%s',
+ '-'*60,
+ pprint.pformat(tuned_stud_radii),
+ pprint.pformat(tuned_separation_forces))
+
+if __name__ == '__main__':
+ app.run(main)
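Note on `_KeepBracketingSolutions`: the wrapper remembers the evaluated points whose residuals lie closest to zero on each side, so `tune_stud_radius` can pick a radius that errs `'above'` or `'below'` the desired force. The sketch below exercises the same bookkeeping on a toy linear residual instead of a simulated separation force. The names `KeepBracketingSolutions`, `Solution` and `toy_residual` are illustrative stand-ins, and the sample points are arbitrary.

```python
import collections

Solution = collections.namedtuple('Solution', ['x', 'residual'])


class KeepBracketingSolutions:
  """Wraps an objective and records the closest residuals on each side of 0."""

  def __init__(self, func):
    self._func = func
    self.below = Solution(x=None, residual=-float('inf'))
    self.above = Solution(x=None, residual=float('inf'))

  def __call__(self, x):
    residual = self._func(x)
    if self.below.residual < residual <= 0:
      self.below = Solution(x=x, residual=residual)
    elif 0 < residual < self.above.residual:
      self.above = Solution(x=x, residual=residual)
    return residual

  @property
  def closest(self):
    if abs(self.below.residual) < self.above.residual:
      return self.below
    return self.above


def toy_residual(x):
  """Toy stand-in for `separation force - desired force`; root at x = 5."""
  return 2.0 * x - 10.0


wrapped = KeepBracketingSolutions(toy_residual)
# Evaluate in an arbitrary order, as a root finder might.
for x in (1.0, 8.0, 4.0, 5.5, 4.75):
  wrapped(x)
```

After these calls `wrapped.below` holds the best point that undershoots the target and `wrapped.above` the best point that overshoots it, regardless of the order in which the solver visited them.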
diff --git a/dm_control/entities/props/duplo/duplo2x4.xml b/dm_control/entities/props/duplo/duplo2x4.xml
new file mode 100644
index 00000000..324502bc
--- /dev/null
+++ b/dm_control/entities/props/duplo/duplo2x4.xml
@@ -0,0 +1,112 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/entities/props/duplo/duplo_test.py b/dm_control/entities/props/duplo/duplo_test.py
new file mode 100644
index 00000000..f9628068
--- /dev/null
+++ b/dm_control/entities/props/duplo/duplo_test.py
@@ -0,0 +1,154 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the Duplo prop."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.entities.props import duplo
+from dm_control.entities.props.duplo import utils
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+# Expected separation force when `variation == 0`
+EXPECTED_FIXED_FORCE = 10.0
+EXPECTED_FIXED_FORCE_TOL = 0.5
+
+# Bounds and median are based on the empirical distribution of separation
+# forces for real Duplo blocks.
+EXPECTED_MIN_FORCE = 6.
+EXPECTED_MAX_FORCE = 18.
+EXPECTED_MEDIAN_FORCE = 12.
+EXPECTED_MEDIAN_FORCE_TOL = 2.
+
+
+class DuploTest(parameterized.TestCase):
+ """Tests for the Duplo prop."""
+
+ def make_bricks(self, seed, *args, **kwargs):
+ top_brick = duplo.Duplo(*args, **kwargs)
+ bottom_brick = duplo.Duplo(*args, **kwargs)
+ # This sets the radius of the studs. NB: we do this for both bricks because
+ # the stud radius has a (tiny!) effect on the mass of the top brick.
+ top_brick.initialize_episode_mjcf(np.random.RandomState(seed))
+ bottom_brick.initialize_episode_mjcf(np.random.RandomState(seed))
+ return top_brick, bottom_brick
+
+ def measure_separation_force(self, seed, *args, **kwargs):
+ top_brick, bottom_brick = self.make_bricks(seed=seed, *args, **kwargs)
+ return utils.measure_separation_force(top_brick, bottom_brick)
+
+ @parameterized.parameters([p._asdict() for p in duplo._STUD_SIZE_PARAMS])
+ def test_separation_force_fixed(self, easy_align, flanges):
+ forces = []
+ for seed in range(3):
+ forces.append(self.measure_separation_force(
+ seed=seed, easy_align=easy_align, flanges=flanges, variation=0.0))
+
+ # Separation forces should all be identical since variation == 0.0.
+ np.testing.assert_array_equal(forces[0], forces[1:])
+
+ # Separation forces should be close to the reference value.
+ self.assertAlmostEqual(forces[0], EXPECTED_FIXED_FORCE,
+ delta=EXPECTED_FIXED_FORCE_TOL)
+
+ @parameterized.parameters([p._asdict() for p in duplo._STUD_SIZE_PARAMS])
+ def test_separation_force_distribution(self, easy_align, flanges):
+ forces = []
+ for seed in range(10):
+ forces.append(self.measure_separation_force(
+ seed=seed, easy_align=easy_align, flanges=flanges, variation=1.0))
+
+ self.assertGreater(min(forces), EXPECTED_MIN_FORCE)
+ self.assertLess(max(forces), EXPECTED_MAX_FORCE)
+ median_force = np.median(forces)
+ median_force_delta = median_force - EXPECTED_MEDIAN_FORCE
+ self.assertLess(
+ abs(median_force_delta), EXPECTED_MEDIAN_FORCE_TOL,
+ msg=('Expected median separation force to be {}+/-{} N, got {} N.'
+ .format(EXPECTED_MEDIAN_FORCE, EXPECTED_MEDIAN_FORCE_TOL,
+ median_force)))
+
+ @parameterized.parameters([p._asdict() for p in duplo._STUD_SIZE_PARAMS])
+ def test_separation_force_identical_with_same_seed(self, easy_align, flanges):
+ def measure(seed):
+ return self.measure_separation_force(
+ seed=seed, easy_align=easy_align, flanges=flanges, variation=1.0)
+
+ first = measure(seed=0)
+ second = measure(seed=0)
+ third = measure(seed=1)
+
+ self.assertEqual(first, second)
+ self.assertNotEqual(first, third)
+
+ def test_exception_if_color_out_of_range(self):
+ invalid_color = (1., 0., 2.)
+ expected_message = duplo._COLOR_NOT_BETWEEN_0_AND_1.format(invalid_color)
+ with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
+ _ = duplo.Duplo(color=invalid_color)
+
+ @parameterized.parameters([p._asdict() for p in duplo._STUD_SIZE_PARAMS])
+ def test_stud_and_hole_sites_align_when_stacked(self, easy_align, flanges):
+ top_brick, bottom_brick = self.make_bricks(
+ easy_align=easy_align, flanges=flanges, seed=0)
+ arena, _ = utils.stack_bricks(top_brick, bottom_brick)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ # Step the physics a few times to allow it to settle.
+ for _ in range(10):
+ physics.step()
+ # When two bricks are stacked, the studs on the bottom brick should align
+ # precisely with the holes on the top brick.
+ bottom_stud_pos = physics.bind(bottom_brick.studs.ravel()).xpos
+ top_hole_pos = physics.bind(top_brick.holes.ravel()).xpos
+ np.testing.assert_allclose(bottom_stud_pos, top_hole_pos, atol=1e-6)
+
+ # TODO(b/120829077): Extend this test to other brick configurations.
+ def test_correct_stud_contacts(self):
+ top_brick, bottom_brick = self.make_bricks(seed=0)
+ arena, _ = utils.stack_bricks(top_brick, bottom_brick)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ # Step the physics a few times to allow it to settle.
+ for _ in range(10):
+ physics.step()
+
+ # Each stud should make 3 contacts - two with flanges, one with a tube.
+ expected_contacts_per_stud = 3
+
+ for stud_site in bottom_brick.studs.flat:
+ stud_geom = bottom_brick.mjcf_model.find('geom', stud_site.name)
+ geom_id = physics.bind(stud_geom).element_id
+
+ # Check that this stud participates in the expected number of contacts.
+ stud_contacts = ((physics.data.contact.geom1 == geom_id) ^
+ (physics.data.contact.geom2 == geom_id))
+ self.assertEqual(stud_contacts.sum(), expected_contacts_per_stud)
+
+ # The normal forces should be roughly equal across contacts.
+ normal_forces = []
+ for contact_id in np.where(stud_contacts)[0]:
+ all_forces = np.empty(6)
+ mjlib.mj_contactForce(physics.model.ptr, physics.data.ptr,
+ contact_id, all_forces)
+ # all_forces is [normal, tangent, tangent, torsion, rolling, rolling]
+ normal_forces.append(all_forces[0])
+ np.testing.assert_allclose(
+ normal_forces[0], normal_forces[1:], rtol=0.05)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/entities/props/duplo/utils.py b/dm_control/entities/props/duplo/utils.py
new file mode 100644
index 00000000..63ddc8f5
--- /dev/null
+++ b/dm_control/entities/props/duplo/utils.py
@@ -0,0 +1,91 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities used in tests, and for tuning the Duplo model."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+from scipy import optimize
+
+
+def stack_bricks(top_brick, bottom_brick):
+ """Stacks two Duplo bricks, returns the arena and top attachment frame."""
+ arena = composer.Arena()
+ # Bottom brick is fixed in place, top brick has a freejoint.
+ arena.attach(bottom_brick)
+ attachment_frame = arena.add_free_entity(top_brick)
+ # Attachment frame is positioned such that the top brick is on top of the
+ # bottom brick.
+ attachment_frame.pos = (0, 0, 0.0192)
+ return arena, attachment_frame
+
+
+def measure_separation_force(top_brick,
+ bottom_brick,
+ min_force=0.,
+ max_force=20.,
+ tolerance=0.01,
+ time_limit=0.5,
+ height_threshold=1e-3):
+ """Utility for measuring the separation force for a pair of Duplo bricks.
+
+ Args:
+ top_brick: An instance of `Duplo` representing the top brick.
+ bottom_brick: An instance of `Duplo` representing the bottom brick.
+ min_force: A force that should be insufficient to separate the bricks (N).
+ max_force: A force that should be sufficient to separate the bricks (N).
+ tolerance: The desired precision of the solution (N).
+ time_limit: The maximum simulation time (s) over which to apply force on
+ each iteration. Increasing this value will result in smaller estimates
+ of the separation force, since given sufficient time the bricks may slip
+ apart gradually under a smaller force. This is due to MuJoCo's soft
+ contact model (see http://mujoco.org/book/index.html#Soft).
+ height_threshold: The distance (m) that the upper brick must move in the
+ z-axis for the bricks to count as separated.
+
+ Returns:
+ A float, the measured separation force (N).
+ """
+ arena, attachment_frame = stack_bricks(top_brick, bottom_brick)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ bound_attachment_frame = physics.bind(attachment_frame)
+
+ def func(force):
+ """Returns +1 if the bricks separate under this force, and -1 otherwise."""
+ with physics.model.disable('gravity'):
+ # Reset the simulation.
+ physics.reset()
+ # Get the initial height.
+ initial_height = bound_attachment_frame.xpos[2]
+ # Apply an upward force to the attachment frame.
+ bound_attachment_frame.xfrc_applied[2] = force
+ # Advance the simulation until either the height threshold or time limit
+ # is reached.
+ while physics.time() < time_limit:
+ physics.step()
+ distance_lifted = bound_attachment_frame.xpos[2] - initial_height
+ if distance_lifted > height_threshold:
+ return 1.0
+ return -1.0
+
+ # Ensure that the min and max forces bracket the true separation force.
+ while func(min_force) > 0:
+ min_force *= 0.5
+ while func(max_force) < 0:
+ max_force *= 2
+
+ return optimize.bisect(func, a=min_force, b=max_force, xtol=tolerance,
+ disp=True)
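Note on `measure_separation_force`: the objective passed to the root finder only reports a sign (+1 if the bricks separate, -1 otherwise), so the search is a bisection over a step function rather than a smooth residual. The sketch below shows that pattern with a hand-written bisection and a toy threshold in place of the physics simulation. It is not SciPy's implementation of `optimize.bisect`, and `bisect_sign`, `separates` and `TRUE_SEPARATION_FORCE` are hypothetical names and values chosen for illustration.

```python
def bisect_sign(func, a, b, xtol):
  """Plain bisection on a function that only reports a sign (+1 or -1)."""
  if func(a) > 0 or func(b) < 0:
    raise ValueError('`a` and `b` must bracket the sign change.')
  while (b - a) > xtol:
    mid = 0.5 * (a + b)
    if func(mid) > 0:
      b = mid  # `mid` separates the bricks: the threshold is at or below it.
    else:
      a = mid  # `mid` is insufficient: the threshold is above it.
  return 0.5 * (a + b)


TRUE_SEPARATION_FORCE = 10.3  # Hypothetical value, in newtons.


def separates(force):
  """Toy stand-in for the simulated pull test: +1 if the bricks separate."""
  return 1.0 if force > TRUE_SEPARATION_FORCE else -1.0


min_force, max_force = 0.0, 5.0
# Grow the upper bound until it is sufficient to separate the bricks.
while separates(max_force) < 0:
  max_force *= 2
estimate = bisect_sign(separates, a=min_force, b=max_force, xtol=0.01)
```

Because the objective is a step function, the returned value is only meaningful to within `xtol`; tightening the tolerance costs one extra simulated pull per halving of the interval.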
diff --git a/dm_control/entities/props/position_detector.py b/dm_control/entities/props/position_detector.py
new file mode 100644
index 00000000..7aab1db6
--- /dev/null
+++ b/dm_control/entities/props/position_detector.py
@@ -0,0 +1,293 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Detects the presence of registered entities within a cuboidal region."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+import numpy as np
+
+_RENDERED_HEIGHT_IN_2D_MODE = 0.01
+
+
+def _ensure_3d(pos):
+ # Pad the array with a zero if its length is 2.
+ if len(pos) == 2:
+ return np.hstack([pos, 0.])
+ return pos
+
+
+class _Detection:
+
+ __slots__ = ('entity', 'detected')
+
+ def __init__(self, entity, detected=False):
+ self.entity = entity
+ self.detected = detected
+
+
+class PositionDetector(composer.Entity):
+ """Detects the presence of registered entities within an axis-aligned box.
+
+ The volume of this detector is defined by a "lower" corner and an "upper"
+ corner, which suffice to define an axis-aligned box.
+ An entity is considered "detected" if the `xpos` value of any one of its geoms
+ lies within the active region defined by this detector. Note that this is NOT
+ a contact-based detector. Generally speaking, a geom will not be detected
+ until it is already "half inside" the region.
+
+ This detector supports both 2D and 3D modes. In 2D mode, the active region
+ has an effectively infinite height along the z-direction.
+
+ This detector also provides an "inverted" detection mode, where an entity is
+ detected when it is not inside the detector's region.
+ """
+
+ def _build(self,
+ pos,
+ size,
+ inverted=False,
+ visible=False,
+ rgba=(1, 1, 1, 1),
+ material=None,
+ detected_rgba=(0, 1, 0, 0.25),
+ retain_substep_detections=False,
+ name='position_detector'):
+ """Builds the detector.
+
+ Args:
+ pos: The position at the center of this detector's active region. Should
+ be an array-like object of length 3 in 3D mode, or length 2 in 2D mode.
+ size: The half-lengths of this detector's active region. Should
+ be an array-like object of length 3 in 3D mode, or length 2 in 2D mode.
+ inverted: (optional) A boolean, whether to operate in inverted detection
+ mode. If `True`, an entity is detected when it is not in the active
+ region.
+ visible: (optional) A boolean, whether this detector is visible by
+ default in rendered images. If `False`, this detector's active zone
+ is placed in MuJoCo rendering group 4, which is not rendered by default,
+ but can be toggled on (e.g. in `dm_control.viewer`) for debugging
+ purposes.
+ rgba: (optional) The color to render when nothing is detected.
+ material: (optional) The material of the position detector.
+ detected_rgba: (optional) The color to render when an entity is detected.
+ retain_substep_detections: (optional) If `True`, the detector will remain
+ activated at the end of a control step if it became activated at any
+ substep. If `False`, the detector reports its instantaneous state.
+ name: (optional) XML element name of this position detector.
+
+ Raises:
+ ValueError: If the `pos` and `size` arrays do not have the same length.
+ """
+ if len(pos) != len(size):
+ raise ValueError('`pos` and `size` should have the same length: '
+ 'got {!r} and {!r}'.format(pos, size))
+
+ self._inverted = inverted
+ self._detected = False
+ self._retain_substep_detections = retain_substep_detections
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ self._entities = []
+ self._entity_geoms = {}
+
+ self._rgba = np.asarray(rgba)
+ self._detected_rgba = np.asarray(detected_rgba)
+
+ render_pos = np.zeros(3)
+ render_pos[:len(pos)] = pos
+
+ render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE)
+ render_size[:len(size)] = size
+
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._site = self._mjcf_root.worldbody.add(
+ 'site', name='detection_zone', type='box',
+ pos=render_pos, size=render_size, rgba=self._rgba, material=material)
+ self._lower_site = self._mjcf_root.worldbody.add(
+ 'site', name='lower', pos=self._lower_3d, size=[0.05],
+ rgba=self._rgba)
+ self._mid_site = self._mjcf_root.worldbody.add(
+ 'site', name='mid', pos=self._mid_3d, size=[0.05],
+ rgba=self._rgba)
+ self._upper_site = self._mjcf_root.worldbody.add(
+ 'site', name='upper', pos=self._upper_3d, size=[0.05],
+ rgba=self._rgba)
+ self._lower_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._lower_site,
+ name='{}_lower'.format(name))
+ self._mid_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._mid_site,
+ name='{}_mid'.format(name))
+ self._upper_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._upper_site,
+ name='{}_upper'.format(name))
+
+ if not visible:
+ self._site.group = composer.SENSOR_SITES_GROUP
+ self._lower_site.group = composer.SENSOR_SITES_GROUP
+ self._mid_site.group = composer.SENSOR_SITES_GROUP
+ self._upper_site.group = composer.SENSOR_SITES_GROUP
+
+ def resize(self, pos, size):
+ if len(pos) != len(size):
+ raise ValueError('`pos` and `size` should have the same length: '
+ 'got {!r} and {!r}'.format(pos, size))
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ render_pos = np.zeros(3)
+ render_pos[:len(pos)] = pos
+
+ render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE)
+ render_size[:len(size)] = size
+
+ self._site.pos = render_pos
+ self._site.size = render_size
+ self._lower_site.pos = self._lower_3d
+ self._mid_site.pos = self._mid_3d
+ self._upper_site.pos = self._upper_3d
+
+ def set_colors(self, rgba, detected_rgba):
+ self.set_color(rgba)
+ self.set_detected_color(detected_rgba)
+
+ def set_color(self, rgba):
+ self._rgba[:3] = rgba
+ self._site.rgba = self._rgba
+
+ def set_detected_color(self, detected_rgba):
+ self._detected_rgba[:3] = detected_rgba
+
+ def set_position(self, physics, pos):
+ physics.bind(self._site).pos = pos
+ size = physics.bind(self._site).size[:3]
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ physics.bind(self._lower_site).pos = self._lower_3d
+ physics.bind(self._mid_site).pos = self._mid_3d
+ physics.bind(self._upper_site).pos = self._upper_3d
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def register_entities(self, *entities):
+ for entity in entities:
+ self._entities.append(_Detection(entity))
+ self._entity_geoms[entity] = entity.mjcf_model.find_all('geom')
+
+ def deregister_entities(self):
+ self._entities = []
+
+ @property
+ def detected_entities(self):
+ """A list of detected entities."""
+ return [
+ detection.entity for detection in self._entities if detection.detected]
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._entity_geoms = {}
+ for detection in self._entities:
+ entity = detection.entity
+ self._entity_geoms[entity] = entity.mjcf_model.find_all('geom')
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._update_detection(physics)
+
+ def before_step(self, physics, unused_random_state):
+ for detection in self._entities:
+ detection.detected = False
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_detection(physics)
+
+ def _is_in_zone(self, xpos):
+ return (np.all(self._lower < xpos[:len(self._lower)])
+ and np.all(self._upper > xpos[:len(self._upper)]))
+
+ def _update_detection(self, physics):
+ self._previously_detected = self._detected
+ self._detected = False
+ for detection in self._entities:
+ if not self._retain_substep_detections:
+ detection.detected = False
+ for geom in self._entity_geoms[detection.entity]:
+ if self._is_in_zone(physics.bind(geom).xpos) != self._inverted:
+ detection.detected = True
+ self._detected = True
+ break
+
+ if self._detected and not self._previously_detected:
+ physics.bind(self._site).rgba = self._detected_rgba
+ elif self._previously_detected and not self._detected:
+ physics.bind(self._site).rgba = self._rgba
+
+ def site_pos(self, physics):
+ return physics.bind(self._site).pos
+
+ @property
+ def activated(self):
+ return self._detected
+
+ @property
+ def upper(self):
+ return self._upper
+
+ @property
+ def lower(self):
+ return self._lower
+
+ @property
+ def mid(self):
+ return (self._lower + self._upper) / 2.
+
+ @property
+ def lower_site(self):
+ return self._lower_site
+
+ @property
+ def mid_site(self):
+ return self._mid_site
+
+ @property
+ def upper_site(self):
+ return self._upper_site
+
+ @property
+ def lower_sensor(self):
+ return self._lower_sensor
+
+ @property
+ def mid_sensor(self):
+ return self._mid_sensor
+
+ @property
+ def upper_sensor(self):
+ return self._upper_sensor
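The containment test performed by `_is_in_zone` and `_update_detection` above can be illustrated in isolation. The following is a minimal pure-Python sketch, not the library's implementation: the helper names `zone_bounds`, `is_in_zone`, and `is_detected` are hypothetical, and plain tuples stand in for the NumPy arrays used by the class. It shows that the inequalities are strict (a point on the boundary is not inside), that a 2D zone ignores the z coordinate, and that `inverted` flips the result.

```python
def zone_bounds(pos, size):
  """Returns (lower, upper) corners of a box given its center and half-sizes."""
  lower = tuple(p - s for p, s in zip(pos, size))
  upper = tuple(p + s for p, s in zip(pos, size))
  return lower, upper


def is_in_zone(lower, upper, xpos):
  """Strictly-inside test using only as many coordinates as the zone has."""
  coords = xpos[:len(lower)]
  return all(lo < x < hi for lo, x, hi in zip(lower, coords, upper))


def is_detected(lower, upper, xpos, inverted=False):
  """An inverted detector fires when the point is outside the zone."""
  return is_in_zone(lower, upper, xpos) != inverted


# A 2D zone centered at (1.0, 2.0) with half-sizes (0.5, 0.25).
lower, upper = zone_bounds((1.0, 2.0), (0.5, 0.25))
inside_far_above = is_detected(lower, upper, (1.0, 2.0, 1e6))
on_boundary = is_detected(lower, upper, (0.5, 2.0, 0.0))
outside_inverted = is_detected(lower, upper, (3.0, 2.0, 0.0), inverted=True)
print(inside_far_above, on_boundary, outside_inverted)
```

The three cases correspond to the situations exercised by the 2D detection test: a prop far above the zone is still detected, a prop placed exactly on the lower corner is not, and an inverted detector reports props that lie outside.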
diff --git a/dm_control/entities/props/position_detector_test.py b/dm_control/entities/props/position_detector_test.py
new file mode 100644
index 00000000..43464bbf
--- /dev/null
+++ b/dm_control/entities/props/position_detector_test.py
@@ -0,0 +1,128 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.entities.props.position_detector."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control.entities.props import position_detector
+from dm_control.entities.props import primitive
+import numpy as np
+
+
+class PositionDetectorTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.arena = composer.Arena()
+ self.props = [
+ primitive.Primitive(geom_type='sphere', size=(0.1,)),
+ primitive.Primitive(geom_type='sphere', size=(0.1,))
+ ]
+ for prop in self.props:
+ self.arena.add_free_entity(prop)
+ self.task = composer.NullTask(self.arena)
+
+ def assertDetected(self, entity, detector):
+ if not self.inverted:
+ self.assertIn(entity, detector.detected_entities)
+ else:
+ self.assertNotIn(entity, detector.detected_entities)
+
+ def assertNotDetected(self, entity, detector):
+ if not self.inverted:
+ self.assertNotIn(entity, detector.detected_entities)
+ else:
+ self.assertIn(entity, detector.detected_entities)
+
+ @parameterized.parameters(False, True)
+ def test3DDetection(self, inverted):
+ self.inverted = inverted
+
+ detector_pos = np.array([0.3, 0.2, 0.1])
+ detector_size = np.array([0.1, 0.2, 0.3])
+ detector = position_detector.PositionDetector(
+ pos=detector_pos, size=detector_size, inverted=inverted)
+ detector.register_entities(*self.props)
+ self.arena.attach(detector)
+ env = composer.Environment(self.task)
+
+ env.reset()
+ self.assertNotDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ def initialize_episode(physics, unused_random_state):
+ for prop in self.props:
+ prop.set_pose(physics, detector_pos)
+ self.task.initialize_episode = initialize_episode
+ env.reset()
+ self.assertDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(env.physics, detector_pos - detector_size)
+ env.step([])
+ self.assertNotDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(env.physics, detector_pos - detector_size / 2)
+ self.props[1].set_pose(env.physics, detector_pos + detector_size * 1.01)
+ env.step([])
+ self.assertDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ @parameterized.parameters(False, True)
+ def test2DDetection(self, inverted):
+ self.inverted = inverted
+
+ detector_pos = np.array([0.3, 0.2])
+ detector_size = np.array([0.1, 0.2])
+ detector = position_detector.PositionDetector(
+ pos=detector_pos, size=detector_size, inverted=inverted)
+ detector.register_entities(*self.props)
+ self.arena.attach(detector)
+ env = composer.Environment(self.task)
+
+ env.reset()
+ self.assertNotDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ def initialize_episode(physics, unused_random_state):
+ # In 2D mode, detection should occur no matter how large |z| is.
+ self.props[0].set_pose(physics, [detector_pos[0], detector_pos[1], 1e+6])
+ self.props[1].set_pose(physics, [detector_pos[0], detector_pos[1], -1e+6])
+ self.task.initialize_episode = initialize_episode
+ env.reset()
+ self.assertDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(
+ env.physics, [detector_pos[0] - detector_size[0], detector_pos[1], 0])
+ env.step([])
+ self.assertNotDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(
+ env.physics, [detector_pos[0] - detector_size[0] / 2,
+ detector_pos[1] + detector_size[1] / 2, 0])
+ self.props[1].set_pose(
+ env.physics, [detector_pos[0], detector_pos[1] + detector_size[1], 0])
+ env.step([])
+ self.assertDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/entities/props/primitive.py b/dm_control/entities/props/primitive.py
new file mode 100644
index 00000000..e3389696
--- /dev/null
+++ b/dm_control/entities/props/primitive.py
@@ -0,0 +1,109 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Prop consisting of a single geom with position and velocity sensors."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+
+
+class Primitive(composer.Entity):
+ """A prop consisting of a single geom with position and velocity sensors."""
+
+ def _build(self, geom_type, size, name=None, **kwargs):
+ """Initializes the prop.
+
+ Args:
+ geom_type: String specifying the geom type.
+ size: List or numpy array of up to 3 numbers, depending on `geom_type`:
+ geom_type='box', size=[x_half_length, y_half_length, z_half_length]
+ geom_type='capsule', size=[radius, half_length]
+ geom_type='cylinder', size=[radius, half_length]
+ geom_type='ellipsoid', size=[x_radius, y_radius, z_radius]
+ geom_type='sphere', size=[radius]
+ name: (optional) A string, the name of this prop.
+ **kwargs: Additional geom parameters. Please see the MuJoCo documentation
+ for further details: http://www.mujoco.org/book/XMLreference.html#geom.
+ """
+ self._mjcf_root = mjcf.element.RootElement(model=name)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', name='geom', type=geom_type, size=size, **kwargs)
+ self._position = self._mjcf_root.sensor.add(
+ 'framepos', name='position', objtype='geom', objname=self.geom)
+ self._orientation = self._mjcf_root.sensor.add(
+ 'framequat', name='orientation', objtype='geom', objname=self.geom)
+ self._linear_velocity = self._mjcf_root.sensor.add(
+ 'framelinvel', name='linear_velocity', objtype='geom',
+ objname=self.geom)
+ self._angular_velocity = self._mjcf_root.sensor.add(
+ 'frameangvel', name='angular_velocity', objtype='geom',
+ objname=self.geom)
+
+ def _build_observables(self):
+ return PrimitiveObservables(self)
+
+ @property
+ def geom(self):
+ """The geom belonging to this prop."""
+ return self._geom
+
+ @property
+ def position(self):
+ """Sensor that returns the prop position."""
+ return self._position
+
+ @property
+ def orientation(self):
+ """Sensor that returns the prop orientation (as a quaternion)."""
+ # TODO(b/120829807): Consider returning a rotation matrix instead.
+ return self._orientation
+
+ @property
+ def linear_velocity(self):
+ """Sensor that returns the linear velocity of the prop."""
+ return self._linear_velocity
+
+ @property
+ def angular_velocity(self):
+ """Sensor that returns the angular velocity of the prop."""
+ return self._angular_velocity
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class PrimitiveObservables(composer.Observables,
+ composer.FreePropObservableMixin):
+ """Primitive entity's observables."""
+
+ @define.observable
+ def position(self):
+ return observable.MJCFFeature('sensordata', self._entity.position)
+
+ @define.observable
+ def orientation(self):
+ return observable.MJCFFeature('sensordata', self._entity.orientation)
+
+ @define.observable
+ def linear_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.linear_velocity)
+
+ @define.observable
+ def angular_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.angular_velocity)
diff --git a/dm_control/entities/props/primitive_test.py b/dm_control/entities/props/primitive_test.py
new file mode 100644
index 00000000..ae8f3265
--- /dev/null
+++ b/dm_control/entities/props/primitive_test.py
@@ -0,0 +1,97 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.entities.props.primitive."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.entities.props import primitive
+import numpy as np
+
+
+class PrimitiveTest(parameterized.TestCase):
+
+ def _make_free_prop(self, geom_type='sphere', size=(0.1,), **kwargs):
+ prop = primitive.Primitive(geom_type=geom_type, size=size, **kwargs)
+ arena = composer.Arena()
+ arena.add_free_entity(prop)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ return prop, physics
+
+ @parameterized.parameters([
+ dict(geom_type='sphere', size=[0.1]),
+ dict(geom_type='capsule', size=[0.1, 0.2]),
+ dict(geom_type='cylinder', size=[0.1, 0.2]),
+ dict(geom_type='box', size=[0.1, 0.2, 0.3]),
+ dict(geom_type='ellipsoid', size=[0.1, 0.2, 0.3]),
+ ])
+ def test_instantiation(self, geom_type, size):
+ name = 'foo'
+ rgba = [1., 0., 1., 0.5]
+ prop, physics = self._make_free_prop(
+ geom_type=geom_type, size=size, name=name, rgba=rgba)
+ # Check that the name and other kwargs are set correctly.
+ self.assertEqual(prop.mjcf_model.model, name)
+ np.testing.assert_array_equal(physics.bind(prop.geom).rgba, rgba)
+ # Check that we can step without anything breaking.
+ physics.step()
+
+ @parameterized.parameters([
+ dict(position=[0., 0., 0.]),
+ dict(position=[0.1, -0.2, 0.3]),
+ ])
+ def test_position_observable(self, position):
+ prop, physics = self._make_free_prop()
+ prop.set_pose(physics, position=position)
+ observation = prop.observables.position(physics)
+ np.testing.assert_array_equal(position, observation)
+
+ @parameterized.parameters([
+ dict(quat=[1., 0., 0., 0.]),
+ dict(quat=[0., -1., 1., 0.]),
+ ])
+ def test_orientation_observable(self, quat):
+ prop, physics = self._make_free_prop()
+ normalized_quat = np.array(quat) / np.linalg.norm(quat)
+ prop.set_pose(physics, quaternion=normalized_quat)
+ observation = prop.observables.orientation(physics)
+ np.testing.assert_array_almost_equal(normalized_quat, observation)
+
+ @parameterized.parameters([
+ dict(velocity=[0., 0., 0.]),
+ dict(velocity=[0.1, -0.2, 0.3]),
+ ])
+ def test_linear_velocity_observable(self, velocity):
+ prop, physics = self._make_free_prop()
+ prop.set_velocity(physics, velocity=velocity)
+ observation = prop.observables.linear_velocity(physics)
+ np.testing.assert_array_almost_equal(velocity, observation)
+
+ @parameterized.parameters([
+ dict(angular_velocity=[0., 0., 0.]),
+ dict(angular_velocity=[0.1, -0.2, 0.3]),
+ ])
+ def test_angular_velocity_observable(self, angular_velocity):
+ prop, physics = self._make_free_prop()
+ prop.set_velocity(physics, angular_velocity=angular_velocity)
+ observation = prop.observables.angular_velocity(physics)
+ np.testing.assert_array_almost_equal(angular_velocity, observation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/README.md b/dm_control/locomotion/README.md
new file mode 100644
index 00000000..b149f159
--- /dev/null
+++ b/dm_control/locomotion/README.md
@@ -0,0 +1,91 @@
+# Locomotion task library
+
+This package contains reusable components for defining control tasks that are
+related to locomotion. New users are encouraged to start by browsing the
+`examples/` subdirectory, which contains preconfigured RL environments
+associated with various research papers. These examples can serve as starting
+points or be customized to design new environments using the components
+available from this library.
+
+
+
+
+
+
+## Terminology
+
+This library facilitates the creation of environments that require **walkers**
+to perform a **task** in an **arena**.
+
+- **walkers** refer to detached bodies that can move around in the
+ environment.
+
+- **arenas** refer to the surroundings in which the walkers and possibly other
+ objects exist.
+
+- **tasks** refer to the specification of observations and rewards that are
+ passed from the "environment" to the "agent", along with runtime details
+ such as initialization and termination logic.
+
+## Installation and requirements
+
+See [the documentation for `dm_control`][installation-and-requirements].
+
+## Quickstart
+
+```python
+from dm_control import composer
+from dm_control.locomotion.examples import basic_cmu_2019
+import numpy as np
+
+# Build an example environment.
+env = basic_cmu_2019.cmu_humanoid_run_walls()
+
+# Get the `action_spec` describing the control inputs.
+action_spec = env.action_spec()
+
+# Step through the environment for one episode with random actions.
+time_step = env.reset()
+while not time_step.last():
+  action = np.random.uniform(action_spec.minimum, action_spec.maximum,
+                             size=action_spec.shape)
+  time_step = env.step(action)
+  print("reward = {}, discount = {}, observations = {}.".format(
+      time_step.reward, time_step.discount, time_step.observation))
+```
+
+[`dm_control.viewer`] can also be used to visualize and interact with the
+environment, e.g.:
+
+```python
+from dm_control import viewer
+
+viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_walls)
+```
+
+## Publications
+
+This library contains environments that were adapted from several research
+papers. Relevant references include:
+
+- [Emergence of Locomotion Behaviours in Rich Environments (2017)][heess2017].
+
+- [Learning human behaviors from motion capture by adversarial imitation
+ (2017)][merel2017].
+
+- [Hierarchical visuomotor control of humanoids (2019)][merel2019a].
+
+- [Neural probabilistic motor primitives for humanoid control (2019)][merel2019b].
+
+- [Deep neuroethology of a virtual rodent (2020)][merel2020].
+
+- [CoMic: Complementary Task Learning & Mimicry for Reusable Skills (2020)][hasenclever2020].
+
+[installation-and-requirements]: ../../README.md#installation-and-requirements
+[`dm_control.viewer`]: ../viewer/README.md
+[heess2017]: https://arxiv.org/abs/1707.02286
+[merel2017]: https://arxiv.org/abs/1707.02201
+[merel2019a]: https://arxiv.org/abs/1811.09656
+[merel2019b]: https://arxiv.org/abs/1811.11711
+[merel2020]: https://openreview.net/pdf?id=SyxrxR4KPS
+[hasenclever2020]: http://proceedings.mlr.press/v119/hasenclever20a.html
diff --git a/dm_control/locomotion/__init__.py b/dm_control/locomotion/__init__.py
new file mode 100644
index 00000000..4224c020
--- /dev/null
+++ b/dm_control/locomotion/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/locomotion/arenas/__init__.py b/dm_control/locomotion/arenas/__init__.py
new file mode 100644
index 00000000..026c0eac
--- /dev/null
+++ b/dm_control/locomotion/arenas/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Arenas for Locomotion tasks."""
+
+from dm_control.locomotion.arenas.bowl import Bowl
+from dm_control.locomotion.arenas.corridors import EmptyCorridor
+from dm_control.locomotion.arenas.corridors import GapsCorridor
+from dm_control.locomotion.arenas.corridors import WallsCorridor
+from dm_control.locomotion.arenas.floors import Floor
+from dm_control.locomotion.arenas.labmaze_textures import FloorTextures
+from dm_control.locomotion.arenas.labmaze_textures import SkyBox
+from dm_control.locomotion.arenas.labmaze_textures import WallTextures
+from dm_control.locomotion.arenas.mazes import MazeWithTargets
+from dm_control.locomotion.arenas.mazes import RandomMazeWithTargets
+from dm_control.locomotion.arenas.padded_room import PaddedRoom
diff --git a/dm_control/locomotion/arenas/assets/__init__.py b/dm_control/locomotion/arenas/assets/__init__.py
new file mode 100644
index 00000000..e4d6db18
--- /dev/null
+++ b/dm_control/locomotion/arenas/assets/__init__.py
@@ -0,0 +1,59 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Locomotion texture assets."""
+
+import collections
+import os
+import sys
+
+ROOT_DIR = '../locomotion/arenas/assets'
+
+
+def get_texturedir(style):
+  return os.path.join(ROOT_DIR, style)
+
+SKY_STYLES = ('outdoor_natural',)
+
+SkyBox = collections.namedtuple(
+    'SkyBox', ('file', 'gridsize', 'gridlayout'))
+
+
+def get_sky_texture_info(style):
+  if style not in SKY_STYLES:
+    raise ValueError('`style` should be one of {}: got {!r}'.format(
+        SKY_STYLES, style))
+  return SkyBox(file='OutdoorSkybox2048.png',
+                gridsize='3 4',
+                gridlayout='.U..LFRB.D..')
+
+
+GROUND_STYLES = ('outdoor_natural',)
+
+GroundTexture = collections.namedtuple(
+    'GroundTexture', ('file', 'type'))
+
+
+def get_ground_texture_info(style):
+  if style not in GROUND_STYLES:
+    raise ValueError('`style` should be one of {}: got {!r}'.format(
+        GROUND_STYLES, style))
+  return GroundTexture(
+      file='OutdoorGrassFloorD.png',
+      type='2d')
+
+
+
+
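The style checks above rely on the `in` operator. As a general Python note (independent of this package), parentheses alone do not create a tuple: a single-entry collection needs a trailing comma, otherwise `in` performs substring matching on a string. A small sketch with illustrative variable names:

```python
# Parentheses without a trailing comma leave a plain string.
as_string = ('outdoor_natural')
# A trailing comma makes a one-element tuple.
as_tuple = ('outdoor_natural',)

# Substring matching on a string accepts partial names...
partial_in_string = 'outdoor' in as_string
# ...whereas tuple membership requires an exact match.
partial_in_tuple = 'outdoor' in as_tuple
exact_in_tuple = 'outdoor_natural' in as_tuple

print(type(as_string).__name__, type(as_tuple).__name__)
print(partial_in_string, partial_in_tuple, exact_in_tuple)
```

With a real tuple, a misspelled or truncated style name fails the membership test and reaches the `ValueError` branch instead of being silently accepted.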
diff --git a/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png b/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png
new file mode 100644
index 00000000..2a93f5b7
Binary files /dev/null and b/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png differ
diff --git a/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png b/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png
new file mode 100644
index 00000000..d6f9a58a
Binary files /dev/null and b/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png differ
diff --git a/dm_control/locomotion/arenas/bowl.py b/dm_control/locomotion/arenas/bowl.py
new file mode 100644
index 00000000..d5fc06be
--- /dev/null
+++ b/dm_control/locomotion/arenas/bowl.py
@@ -0,0 +1,135 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Bowl arena with bumps."""
+
+
+from dm_control import composer
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+from scipy import ndimage
+
+mjlib = mjbindings.mjlib
+
+_TOP_CAMERA_DISTANCE = 100
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+
+# Constants related to terrain generation.
+_TERRAIN_SMOOTHNESS = .5 # 0.0: maximally bumpy; 1.0: completely smooth.
+_TERRAIN_BUMP_SCALE = .2 # Spatial scale of terrain bumps (in meters).
+
+
+class Bowl(composer.Arena):
+ """A bowl arena with sinusoidal bumps."""
+
+ def _build(self, size=(10, 10), aesthetic='default', name='bowl'):
+ super()._build(name=name)
+
+ self._hfield = self._mjcf_root.asset.add(
+ 'hfield',
+ name='terrain',
+ nrow=201,
+ ncol=201,
+ size=(6, 6, 0.5, 0.1))
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._texture,
+ texuniform='true')
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ self._terrain_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ name='terrain',
+ type='hfield',
+ pos=(0, 0, -0.01),
+ hfield='terrain',
+ material=self._material)
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ size=list(size) + [0.5],
+ material=self._material)
+ else:
+ self._terrain_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ name='terrain',
+ type='hfield',
+ rgba=(0.2, 0.3, 0.4, 1),
+ pos=(0, 0, -0.01),
+ hfield='terrain')
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ rgba=(0.2, 0.3, 0.4, 1),
+ size=list(size) + [0.5])
+
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1])
+
+ self._regenerate = True
+
+ def regenerate(self, random_state):
+ # Regenerating the bowl requires physics, so defer it to initialization.
+ self._regenerate = True
+
+ def initialize_episode(self, physics, random_state):
+ if self._regenerate:
+ self._regenerate = False
+
+ # Get heightfield resolution, assert that it is square.
+ res = physics.bind(self._hfield).nrow
+ assert res == physics.bind(self._hfield).ncol
+
+ # Sinusoidal bowl shape.
+ row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
+ radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .1, 1)
+ bowl_shape = .5 - np.cos(2*np.pi*radius)/2
+
+ # Random smooth bumps.
+ terrain_size = 2 * physics.bind(self._hfield).size[0]
+ bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
+ bumps = random_state.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
+ smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
+
+ # Terrain is elementwise product.
+ terrain = bowl_shape * smooth_bumps
+ start_idx = physics.bind(self._hfield).adr
+ physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
+
+ # If we have a rendering context, we need to re-upload the modified
+ # heightfield data.
+ if physics.contexts:
+ with physics.contexts.gl.make_current() as ctx:
+ ctx.call(mjlib.mjr_uploadHField,
+ physics.model.ptr,
+ physics.contexts.mujoco.ptr,
+ physics.bind(self._hfield).element_id)
+
+ @property
+ def ground_geoms(self):
+ return (self._terrain_geom, self._ground_geom)
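The bowl profile computed in `initialize_episode` above depends only on the clipped radial distance from the center of the heightfield. A minimal sketch of that formula using the standard `math` module follows; the function name `bowl_height` is hypothetical, and the random bumps that the arena multiplies in are omitted.

```python
import math


def bowl_height(x, y):
  """Unscaled profile height at normalized grid coordinates x, y in [-1, 1]."""
  # The radius is clipped to [0.1, 1], as in the arena code.
  radius = min(max(math.hypot(x, y), 0.1), 1.0)
  return 0.5 - math.cos(2 * math.pi * radius) / 2


center = bowl_height(0.0, 0.0)   # radius clipped up to 0.1
ridge = bowl_height(0.5, 0.0)    # radius 0.5, where the cosine term is -1
edge = bowl_height(1.0, 0.0)     # radius 1, where the cosine term is +1
corner = bowl_height(1.0, 1.0)   # radius clipped down to 1
print(center, ridge, edge, corner)
```

Under this formula the profile is 0 at radius 1 (including the clipped corners), reaches 1 at radius 0.5, and is roughly 0.095 at the clipped center.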
diff --git a/dm_control/locomotion/arenas/bowl_test.py b/dm_control/locomotion/arenas/bowl_test.py
new file mode 100644
index 00000000..441266b0
--- /dev/null
+++ b/dm_control/locomotion/arenas/bowl_test.py
@@ -0,0 +1,32 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.bowl."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import bowl
+
+
+class BowlTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+
+ arena = bowl.Bowl()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/arenas/corridors.py b/dm_control/locomotion/arenas/corridors.py
new file mode 100644
index 00000000..04eb298b
--- /dev/null
+++ b/dm_control/locomotion/arenas/corridors.py
@@ -0,0 +1,443 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Corridor-based arenas."""
+
+import abc
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+
+_SIDE_WALLS_GEOM_GROUP = 3
+_CORRIDOR_X_PADDING = 2.0
+_WALL_THICKNESS = 0.16
+_SIDE_WALL_HEIGHT = 4.0
+_DEFAULT_ALPHA = 0.5
+
+
+class Corridor(composer.Arena, metaclass=abc.ABCMeta):
+ """Abstract base class for corridor-type arenas."""
+
+ @abc.abstractmethod
+ def regenerate(self, random_state):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def corridor_length(self):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def corridor_width(self):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def ground_geoms(self):
+ raise NotImplementedError
+
+ def is_at_target_position(self, position, tolerance=0.0):
+ """Checks if a `position` is within `tolerance` of an end of the corridor.
+
+ This can also be used to evaluate more complicated T-shaped or L-shaped
+ corridors.
+
+ Args:
+ position: An iterable of 2 elements corresponding to the x and y location
+ of the position to evaluate.
+ tolerance: A `float` tolerance to use while evaluating the position.
+
+ Returns:
+ A `bool` indicating whether the `position` is within the `tolerance` of an
+ end of the corridor.
+ """
+ x, _ = position
+ return x > self.corridor_length - tolerance
+
+
+class EmptyCorridor(Corridor):
+ """An empty corridor with planes around the perimeter."""
+
+ def _build(self,
+ corridor_width=4,
+ corridor_length=40,
+ visible_side_planes=True,
+ name='empty_corridor'):
+ """Builds the corridor.
+
+ Args:
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ visible_side_planes: Whether the side planes that bound the corridor's
+ perimeter should be rendered.
+ name: The name of this arena.
+ """
+ super()._build(name=name)
+
+ self._corridor_width = corridor_width
+ self._corridor_length = corridor_length
+
+ self._walls_body = self._mjcf_root.worldbody.add('body', name='walls')
+
+ self._mjcf_root.visual.map.znear = 0.0005
+ self._mjcf_root.asset.add(
+ 'texture', type='skybox', builtin='gradient',
+ rgb1=[0.4, 0.6, 0.8], rgb2=[0, 0, 0], width=100, height=600)
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[0.4, 0.4, 0.4], diffuse=[0.8, 0.8, 0.8],
+ specular=[0.1, 0.1, 0.1])
+
+ alpha = _DEFAULT_ALPHA if visible_side_planes else 0.0
+ self._ground_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', rgba=[0.5, 0.5, 0.5, 1], size=[1, 1, 1])
+ self._left_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[1, 0, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._right_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[-1, 0, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._near_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[0, 1, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._far_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[0, -1, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+
+ self._current_corridor_length = None
+ self._current_corridor_width = None
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_length`
+ distributions specified in `_build`. The corridor is resized accordingly.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ self._walls_body.geom.clear()
+ corridor_width = variation.evaluate(self._corridor_width,
+ random_state=random_state)
+ corridor_length = variation.evaluate(self._corridor_length,
+ random_state=random_state)
+ self._current_corridor_length = corridor_length
+ self._current_corridor_width = corridor_width
+
+ self._ground_plane.pos = [corridor_length / 2, 0, 0]
+ self._ground_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, corridor_width / 2, 1]
+
+ self._left_plane.pos = [
+ corridor_length / 2, corridor_width / 2, _SIDE_WALL_HEIGHT / 2]
+ self._left_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._right_plane.pos = [
+ corridor_length / 2, -corridor_width / 2, _SIDE_WALL_HEIGHT / 2]
+ self._right_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._near_plane.pos = [
+ -_CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2]
+ self._near_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._far_plane.pos = [
+ corridor_length + _CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2]
+ self._far_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1]
+
+ @property
+ def corridor_length(self):
+ return self._current_corridor_length
+
+ @property
+ def corridor_width(self):
+ return self._current_corridor_width
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,)
+
+
+class GapsCorridor(EmptyCorridor):
+ """A corridor that consists of multiple platforms separated by gaps."""
+
+ # pylint: disable=arguments-renamed
+ def _build(self,
+ platform_length=1.,
+ gap_length=2.5,
+ corridor_width=4,
+ corridor_length=40,
+ ground_rgba=(0.5, 0.5, 0.5, 1),
+ visible_side_planes=False,
+ aesthetic='default',
+ name='gaps_corridor'):
+ """Builds the corridor.
+
+ Args:
+ platform_length: A number or a `composer.variation.Variation` object that
+ specifies the size of the platforms along the corridor.
+ gap_length: A number or a `composer.variation.Variation` object that
+ specifies the size of the gaps along the corridor.
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ ground_rgba: A sequence of 4 numbers or a `composer.variation.Variation`
+ object specifying the color of the ground.
+ visible_side_planes: Whether the side planes that bound the corridor's
+ perimeter should be rendered.
+ aesthetic: Option to adjust the material properties and skybox.
+ name: The name of this arena.
+ """
+ super()._build(
+ corridor_width=corridor_width,
+ corridor_length=corridor_length,
+ visible_side_planes=visible_side_planes,
+ name=name)
+
+ self._platform_length = platform_length
+ self._gap_length = gap_length
+ self._ground_rgba = ground_rgba
+ self._aesthetic = aesthetic
+
+ if self._aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._ground_texture,
+ texuniform='true')
+ # Remove the existing skybox.
+ for texture in self._mjcf_root.asset.find_all('texture'):
+ if texture.type == 'skybox':
+ texture.remove()
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+
+ self._ground_body = self._mjcf_root.worldbody.add('body', name='ground')
+
+ # pylint: enable=arguments-renamed
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_length`
+ distributions specified in `_build`. The corridor is resized accordingly,
+ and new sets of platforms are created according to values drawn from the
+ `platform_length`, `gap_length`, and `ground_rgba` distributions specified
+ in `_build`.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ # Resize the entire corridor first.
+ super().regenerate(random_state)
+
+ # Move the ground plane down and make it invisible.
+ self._ground_plane.pos = [self._current_corridor_length / 2, 0, -10]
+ self._ground_plane.rgba = [0, 0, 0, 0]
+
+ # Clear the existing platform pieces.
+ self._ground_body.geom.clear()
+
+ # Make the first platform larger.
+ platform_length = 3. * _CORRIDOR_X_PADDING
+ platform_pos = [
+ platform_length / 2,
+ 0,
+ -_WALL_THICKNESS,
+ ]
+ platform_size = [
+ platform_length / 2,
+ self._current_corridor_width / 2,
+ _WALL_THICKNESS,
+ ]
+ if self._aesthetic != 'default':
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ name='start_floor',
+ pos=platform_pos,
+ size=platform_size,
+ material=self._ground_material)
+ else:
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ rgba=variation.evaluate(self._ground_rgba, random_state),
+ name='start_floor',
+ pos=platform_pos,
+ size=platform_size)
+
+ current_x = platform_length
+ platform_id = 0
+ while current_x < self._current_corridor_length:
+ platform_length = variation.evaluate(
+ self._platform_length, random_state=random_state)
+ platform_pos = [
+ current_x + platform_length / 2.,
+ 0,
+ -_WALL_THICKNESS,
+ ]
+ platform_size = [
+ platform_length / 2,
+ self._current_corridor_width / 2,
+ _WALL_THICKNESS,
+ ]
+ if self._aesthetic != 'default':
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ name='floor_{}'.format(platform_id),
+ pos=platform_pos,
+ size=platform_size,
+ material=self._ground_material)
+ else:
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ rgba=variation.evaluate(self._ground_rgba, random_state),
+ name='floor_{}'.format(platform_id),
+ pos=platform_pos,
+ size=platform_size)
+
+ platform_id += 1
+
+ # Move x to start of the next platform.
+ current_x += platform_length + variation.evaluate(
+ self._gap_length, random_state=random_state)
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,) + tuple(self._ground_body.find_all('geom'))
+
+
+class WallsCorridor(EmptyCorridor):
+ """A corridor obstructed by multiple walls aligned against the two sides."""
+
+ # pylint: disable=arguments-renamed
+ def _build(self,
+ wall_gap=2.5,
+ wall_width=2.5,
+ wall_height=2.0,
+ swap_wall_side=True,
+ wall_rgba=(1, 1, 1, 1),
+ corridor_width=4,
+ corridor_length=40,
+ visible_side_planes=False,
+ include_initial_padding=True,
+ name='walls_corridor'):
+ """Builds the corridor.
+
+ Args:
+ wall_gap: A number or a `composer.variation.Variation` object that
+ specifies the gap between each consecutive pair of obstructing walls.
+ wall_width: A number or a `composer.variation.Variation` object that
+ specifies the width that the obstructing walls extend into the corridor.
+ wall_height: A number or a `composer.variation.Variation` object that
+ specifies the height of the obstructing walls.
+ swap_wall_side: A boolean or a `composer.variation.Variation` object that
+ specifies whether the next obstructing wall should be aligned against
+ the opposite side of the corridor compared to the previous one.
+ wall_rgba: A sequence of 4 numbers or a `composer.variation.Variation`
+ object specifying the color of the walls.
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ visible_side_planes: Whether the side planes that bound the corridor's
+ perimeter should be rendered.
+ include_initial_padding: Whether to include initial offset before first
+ obstacle.
+ name: The name of this arena.
+ """
+ super()._build(
+ corridor_width=corridor_width,
+ corridor_length=corridor_length,
+ visible_side_planes=visible_side_planes,
+ name=name)
+
+ self._wall_height = wall_height
+ self._wall_rgba = wall_rgba
+ self._wall_gap = wall_gap
+ self._wall_width = wall_width
+ self._swap_wall_side = swap_wall_side
+ self._include_initial_padding = include_initial_padding
+
+ # pylint: enable=arguments-renamed
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_length`
+ distributions specified in `_build`. The corridor is resized accordingly,
+ and new sets of obstructing walls are created according to values drawn
+ from the `wall_gap`, `wall_width`, `wall_height`, and `wall_rgba`
+ distributions specified in `_build`.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ super().regenerate(random_state)
+
+ wall_x = variation.evaluate(
+ self._wall_gap, random_state=random_state) - _CORRIDOR_X_PADDING
+ if self._include_initial_padding:
+ wall_x += 2*_CORRIDOR_X_PADDING
+ wall_side = 0
+ wall_id = 0
+ while wall_x < self._current_corridor_length:
+ wall_width = variation.evaluate(
+ self._wall_width, random_state=random_state)
+ wall_height = variation.evaluate(
+ self._wall_height, random_state=random_state)
+ wall_rgba = variation.evaluate(self._wall_rgba, random_state=random_state)
+ if variation.evaluate(self._swap_wall_side, random_state=random_state):
+ wall_side = 1 - wall_side
+
+ wall_pos = [
+ wall_x,
+ (2 * wall_side - 1) * (self._current_corridor_width - wall_width) / 2,
+ wall_height / 2
+ ]
+ wall_size = [_WALL_THICKNESS / 2, wall_width / 2, wall_height / 2]
+ self._walls_body.add(
+ 'geom',
+ type='box',
+ name='wall_{}'.format(wall_id),
+ pos=wall_pos,
+ size=wall_size,
+ rgba=wall_rgba)
+
+ wall_id += 1
+ wall_x += variation.evaluate(self._wall_gap, random_state=random_state)
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,)
diff --git a/dm_control/locomotion/arenas/corridors_test.py b/dm_control/locomotion/arenas/corridors_test.py
new file mode 100644
index 00000000..600ae612
--- /dev/null
+++ b/dm_control/locomotion/arenas/corridors_test.py
@@ -0,0 +1,85 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.arenas.corridors."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.variation import deterministic
+from dm_control.locomotion.arenas import corridors
+
+
+class CorridorsTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ corridors.EmptyCorridor,
+ corridors.GapsCorridor,
+ corridors.WallsCorridor,
+ ])
+ def test_can_compile_mjcf(self, arena_type):
+ arena = arena_type()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ @parameterized.parameters([
+ corridors.EmptyCorridor,
+ corridors.GapsCorridor,
+ corridors.WallsCorridor,
+ ])
+ def test_can_regenerate_corridor_size(self, arena_type):
+ width_sequence = [5.2, 3.8, 7.4]
+ length_sequence = [21.1, 19.4, 16.3]
+
+ arena = arena_type(
+ corridor_width=deterministic.Sequence(width_sequence),
+ corridor_length=deterministic.Sequence(length_sequence))
+
+ # Add a probe geom that will generate contacts with the side walls.
+ probe_body = arena.mjcf_model.worldbody.add('body', name='probe')
+ probe_joint = probe_body.add('freejoint')
+ probe_geom = probe_body.add('geom', name='probe', type='box')
+
+ for expected_width, expected_length in zip(width_sequence, length_sequence):
+ # No random_state is required since we are using deterministic variations.
+ arena.regenerate(random_state=None)
+
+ def resize_probe_geom_and_assert_num_contacts(
+ delta_size, expected_num_contacts,
+ expected_width=expected_width, expected_length=expected_length):
+ probe_geom.size = [
+ (expected_length / 2 + corridors._CORRIDOR_X_PADDING) + delta_size,
+ expected_width / 2 + delta_size, 0.1]
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ probe_geomid = physics.bind(probe_geom).element_id
+ physics.bind(probe_joint).qpos[:3] = [expected_length / 2, 0, 100]
+ physics.forward()
+ probe_contacts = [c for c in physics.data.contact
+ if c.geom1 == probe_geomid or c.geom2 == probe_geomid]
+ self.assertLen(probe_contacts, expected_num_contacts)
+
+ epsilon = 1e-7
+
+ # If the probe geom is epsilon-smaller than the expected corridor size,
+ # then we expect to detect no contact.
+ resize_probe_geom_and_assert_num_contacts(-epsilon, 0)
+
+ # If the probe geom is epsilon-larger than the expected corridor size,
+ # then we expect to generate 4 contacts with each side wall, so 16 total.
+ resize_probe_geom_and_assert_num_contacts(epsilon, 16)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/arenas/covering.py b/dm_control/locomotion/arenas/covering.py
new file mode 100644
index 00000000..95df587f
--- /dev/null
+++ b/dm_control/locomotion/arenas/covering.py
@@ -0,0 +1,137 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Calculates a covering of text mazes with overlapping rectangular walls."""
+
+import collections
+import numpy as np
+
+GridCoordinates = collections.namedtuple('GridCoordinates', ('y', 'x'))
+MazeWall = collections.namedtuple('MazeWall', ('start', 'end'))
+
+
+class _MazeWallCoveringContext:
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ This class uses a greedy algorithm to try and minimize the number of geoms
+ generated to create a given maze. The solution is not guaranteed to be
+ optimal, but in most cases should result in a significantly smaller number of
+ geoms than if each cell were treated as an individual box.
+ """
+
+ def __init__(self, text_maze, wall_char='*', make_odd_sized_walls=False):
+ """Initializes this _MazeWallCoveringContext.
+
+ Args:
+ text_maze: A `labmaze.TextGrid` instance.
+ wall_char: (optional) The character that signifies a wall.
+ make_odd_sized_walls: (optional) A boolean, if `True` all wall sections
+ generated span odd numbers of grid cells. This option exists primarily
+ to appease MuJoCo's texture repeating algorithm.
+ """
+ self._text_maze = text_maze
+ self._wall_char = wall_char
+ self._make_odd_sized_walls = make_odd_sized_walls
+ self._covered = np.full(text_maze.shape, False, dtype=bool)
+ self._maze_size = GridCoordinates(*text_maze.shape)
+ self._next_start = GridCoordinates(0, 0)
+ self._calculated = False
+ self._walls = ()
+
+ def calculate(self):
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ Returns:
+ A tuple of `MazeWall` objects, each describing the corners of a wall.
+ """
+ if not self._calculated:
+ self._calculated = True
+ self._find_next_start()
+ walls = []
+ while self._next_start.y < self._maze_size.y:
+ walls.append(self._find_next_wall())
+ self._find_next_start()
+ self._walls = tuple(walls)
+ return self._walls
+
+ def _find_next_start(self):
+ """Moves `self._next_start` to the top-left corner of the next wall."""
+ for y in range(self._next_start.y, self._maze_size.y):
+ start_x = self._next_start.x if y == self._next_start.y else 0
+ for x in range(start_x, self._maze_size.x):
+ if self._text_maze[y, x] == self._wall_char and not self._covered[y, x]:
+ self._next_start = GridCoordinates(y, x)
+ return
+ self._next_start = self._maze_size
+
+ def _scan_row(self, row, start_col, end_col):
+ """Scans a row of text maze to find the longest strip of wall."""
+ for col in range(start_col, end_col):
+ if (self._text_maze[row, col] != self._wall_char
+ or self._covered[row, col]):
+ return col
+ return end_col
+
+ def _find_next_wall(self):
+ """Finds the largest piece of rectangular wall at the current location.
+
+ This function assumes that `self._next_start` is already at the top-left
+ corner of the next piece of wall.
+
+ Returns:
+ A `MazeWall` named tuple representing the next piece of wall created.
+ """
+ start = self._next_start
+ x = self._maze_size.x
+ end_x_for_rows = []
+ total_cells = []
+
+ for y in range(start.y, self._maze_size.y):
+ x = self._scan_row(y, start.x, x)
+ if x > start.x:
+ if self._make_odd_sized_walls and (x - start.x) % 2 == 0:
+ x -= 1
+ end_x_for_rows.append(x)
+ total_cells.append((x - start.x) * (y - start.y + 1))
+ y += 1
+ else:
+ break
+
+ if not self._make_odd_sized_walls:
+ end_y_offset = total_cells.index(max(total_cells))
+ else:
+ end_y_offset = 2 * total_cells[::2].index(max(total_cells[::2]))
+ end = GridCoordinates(start.y + end_y_offset + 1,
+ end_x_for_rows[end_y_offset])
+ self._covered[start.y:end.y, start.x:end.x] = True
+ self._next_start = GridCoordinates(start.y, end.x)
+ return MazeWall(start, end)
+
+
+def make_walls(text_maze, wall_char='*', make_odd_sized_walls=False):
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ Args:
+ text_maze: A `labmaze.TextGrid` instance.
+ wall_char: (optional) The character that signifies a wall.
+ make_odd_sized_walls: (optional) A boolean, if `True` all wall sections
+ generated span odd numbers of grid cells. This option exists primarily
+ to appease MuJoCo's texture repeating algorithm.
+
+ Returns:
+ A tuple of `MazeWall` objects, each describing the corners of a wall.
+ """
+ wall_covering_context = _MazeWallCoveringContext(
+ text_maze, wall_char=wall_char, make_odd_sized_walls=make_odd_sized_walls)
+ return wall_covering_context.calculate()
diff --git a/dm_control/locomotion/arenas/covering_test.py b/dm_control/locomotion/arenas/covering_test.py
new file mode 100644
index 00000000..9e838423
--- /dev/null
+++ b/dm_control/locomotion/arenas/covering_test.py
@@ -0,0 +1,72 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.covering."""
+
+
+from absl.testing import absltest
+from dm_control.locomotion.arenas import covering
+import labmaze
+import numpy as np
+
+_STRING_DTYPE = '|U1'
+
+
+class CoveringTest(absltest.TestCase):
+
+ def testRandomMazes(self):
+ maze = labmaze.RandomMaze(height=17, width=17,
+ max_rooms=5, room_min_size=3, room_max_size=5,
+ spawns_per_room=0, objects_per_room=0,
+ random_seed=54321)
+ for _ in range(1000):
+ maze.regenerate()
+ walls = covering.make_walls(maze.entity_layer)
+ reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE)
+ for wall in walls:
+ reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*'
+ np.testing.assert_array_equal(reconstructed, maze.entity_layer)
+
+ def testOddCovering(self):
+ maze = labmaze.RandomMaze(height=17, width=17,
+ max_rooms=5, room_min_size=3, room_max_size=5,
+ spawns_per_room=0, objects_per_room=0,
+ random_seed=54321)
+ for _ in range(1000):
+ maze.regenerate()
+ walls = covering.make_walls(maze.entity_layer, make_odd_sized_walls=True)
+ reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE)
+ for wall in walls:
+ reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*'
+ np.testing.assert_array_equal(reconstructed, maze.entity_layer)
+ for wall in walls:
+ self.assertEqual((wall.end.y - wall.start.y) % 2, 1)
+ self.assertEqual((wall.end.x - wall.start.x) % 2, 1)
+
+ def testNoOverlappingWalls(self):
+ maze_string = """..**
+ .***
+ .***
+ """.replace(' ', '')
+ walls = covering.make_walls(labmaze.TextGrid(maze_string))
+ surface = 0
+ for wall in walls:
+ size_x = wall.end.x - wall.start.x
+ size_y = wall.end.y - wall.start.y
+ surface += size_x * size_y
+ self.assertEqual(surface, 8)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/arenas/floors.py b/dm_control/locomotion/arenas/floors.py
new file mode 100644
index 00000000..2a3f19d3
--- /dev/null
+++ b/dm_control/locomotion/arenas/floors.py
@@ -0,0 +1,104 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Simple floor arenas."""
+
+
+from dm_control import composer
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+import numpy as np
+
+_GROUNDPLANE_QUAD_SIZE = 0.25
+
+
+class Floor(composer.Arena):
+ """A simple floor arena with a checkered pattern."""
+
+ def _build(self, size=(8, 8), reflectance=.2, aesthetic='default',
+ name='floor', top_camera_y_padding_factor=1.1,
+ top_camera_distance=100):
+ super()._build(name=name)
+ self._size = size
+ self._top_camera_y_padding_factor = top_camera_y_padding_factor
+ self._top_camera_distance = top_camera_distance
+
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1])
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._ground_texture,
+ texuniform='true')
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ else:
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture',
+ rgb1=[.2, .3, .4],
+ rgb2=[.1, .2, .3],
+ type='2d',
+ builtin='checker',
+ name='groundplane',
+ width=200,
+ height=200,
+ mark='edge',
+ markrgb=[0.8, 0.8, 0.8])
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material',
+ name='groundplane',
+ texrepeat=[2, 2], # Makes white squares exactly 1x1 length units.
+ texuniform=True,
+ reflectance=reflectance,
+ texture=self._ground_texture)
+
+ # Build groundplane.
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ material=self._ground_material,
+ size=list(size) + [_GROUNDPLANE_QUAD_SIZE])
+
+ # Choose the FOV so that the floor always fits nicely within the frame
+ # irrespective of actual floor size.
+ fovy_radians = 2 * np.arctan2(top_camera_y_padding_factor * size[1],
+ top_camera_distance)
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera',
+ name='top_camera',
+ pos=[0, 0, top_camera_distance],
+ quat=[1, 0, 0, 0],
+ fovy=np.rad2deg(fovy_radians))
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_geom,)
+
+ def regenerate(self, random_state):
+ pass
+
+ @property
+ def size(self):
+ return self._size
diff --git a/dm_control/locomotion/arenas/floors_test.py b/dm_control/locomotion/arenas/floors_test.py
new file mode 100644
index 00000000..24e5fa5e
--- /dev/null
+++ b/dm_control/locomotion/arenas/floors_test.py
@@ -0,0 +1,50 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.arenas.floors."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import floors
+import numpy as np
+
+
+class FloorsTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+ arena = floors.Floor()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ def test_size(self):
+ floor_size = (12.9, 27.1)
+ arena = floors.Floor(size=floor_size)
+ self.assertEqual(tuple(arena.ground_geoms[0].size[:2]), floor_size)
+
+ def test_top_camera(self):
+ floor_width, floor_height = 12.9, 27.1
+ arena = floors.Floor(size=[floor_width, floor_height])
+
+ self.assertGreater(arena._top_camera_y_padding_factor, 1)
+ np.testing.assert_array_equal(arena._top_camera.quat, (1, 0, 0, 0))
+
+ expected_camera_y = floor_height * arena._top_camera_y_padding_factor
+ np.testing.assert_allclose(
+ np.tan(np.deg2rad(arena._top_camera.fovy / 2)),
+ expected_camera_y / arena._top_camera.pos[2])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/arenas/labmaze_textures.py b/dm_control/locomotion/arenas/labmaze_textures.py
new file mode 100644
index 00000000..dd29577c
--- /dev/null
+++ b/dm_control/locomotion/arenas/labmaze_textures.py
@@ -0,0 +1,83 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LabMaze textures."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+from labmaze import assets as labmaze_assets
+
+
+class SkyBox(composer.Entity):
+ """Represents a texture asset for the sky box."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_sky_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', type='skybox', name='texture',
+ fileleft=labmaze_textures.left, fileright=labmaze_textures.right,
+ fileup=labmaze_textures.up, filedown=labmaze_textures.down,
+ filefront=labmaze_textures.front, fileback=labmaze_textures.back)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def texture(self):
+ return self._texture
+
+
+class WallTextures(composer.Entity):
+ """Represents wall texture assets."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_wall_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._textures = []
+ for texture_name, texture_path in labmaze_textures.items():
+ self._textures.append(self._mjcf_root.asset.add(
+ 'texture', type='2d', name=texture_name,
+ file=texture_path.format(texture_name)))
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def textures(self):
+ return self._textures
+
+
+class FloorTextures(composer.Entity):
+ """Represents floor texture assets."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_floor_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._textures = []
+ for texture_name, texture_path in labmaze_textures.items():
+ self._textures.append(self._mjcf_root.asset.add(
+ 'texture', type='2d', name=texture_name,
+ file=texture_path.format(texture_name)))
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def textures(self):
+ return self._textures
diff --git a/dm_control/locomotion/arenas/mazes.py b/dm_control/locomotion/arenas/mazes.py
new file mode 100644
index 00000000..804c24cb
--- /dev/null
+++ b/dm_control/locomotion/arenas/mazes.py
@@ -0,0 +1,460 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Maze-based arenas."""
+
+import string
+
+from absl import logging
+from dm_control import composer
+from dm_control.composer.observation import observable
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+from dm_control.locomotion.arenas import covering
+import labmaze
+import numpy as np
+
+
+# Put all "actual" wall geoms in a separate group since they are not rendered.
+_WALL_GEOM_GROUP = 3
+
+_TOP_CAMERA_DISTANCE = 100
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+
+_DEFAULT_WALL_CHAR = '*'
+_DEFAULT_FLOOR_CHAR = '.'
+
+
+class MazeWithTargets(composer.Arena):
+ """A 2D maze with target positions specified by a LabMaze-style text maze."""
+
+ def _build(self, maze, xy_scale=2.0, z_height=2.0,
+ skybox_texture=None, wall_textures=None, floor_textures=None,
+ aesthetic='default', name='maze'):
+ """Initializes this maze arena.
+
+ Args:
+ maze: A `labmaze.BaseMaze` instance.
+ xy_scale: The size of each maze cell in metres.
+ z_height: The z-height of the maze in metres.
+ skybox_texture: (optional) A `composer.Entity` that provides a texture
+ asset for the skybox.
+ wall_textures: (optional) Either a `composer.Entity` that provides texture
+ assets for the maze walls, or a dict mapping printable characters to
+ such Entities. In the former case, the maze walls are assumed to be
+ represented by '*' in the maze's entity layer. In the latter case,
+ the dict's keys specify the different characters that can be present
+ in the maze's entity layer, and the dict's values are the corresponding
+ texture providers.
+ floor_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze floor. Unlike with walls, we do not currently
+ support per-variation floor texture. Instead, we sample textures from
+ the same texture provider for each variation in the variations layer.
+ aesthetic: (optional) Option to adjust the material properties and skybox.
+ name: (optional) A string, the name of this arena.
+ """
+ super()._build(name)
+ self._maze = maze
+ self._xy_scale = xy_scale
+ self._z_height = z_height
+
+ self._x_offset = (self._maze.width - 1) / 2
+ self._y_offset = (self._maze.height - 1) / 2
+
+ self._mjcf_root.default.geom.rgba = [1, 1, 1, 1]
+
+ if aesthetic != 'default':
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ elif skybox_texture:
+ self._skybox_texture = skybox_texture.texture
+ self.attach(skybox_texture)
+ else:
+ self._skybox_texture = self._mjcf_root.asset.add(
+ 'texture', type='skybox', name='skybox', builtin='gradient',
+ rgb1=[.4, .6, .8], rgb2=[0, 0, 0], width=100, height=100)
+
+ self._texturing_geom_names = []
+ self._texturing_material_names = []
+ if wall_textures:
+ if isinstance(wall_textures, dict):
+ for texture_provider in set(wall_textures.values()):
+ self.attach(texture_provider)
+ self._wall_textures = {
+ wall_char: texture_provider.textures
+ for wall_char, texture_provider in wall_textures.items()
+ }
+ else:
+ self.attach(wall_textures)
+ self._wall_textures = {_DEFAULT_WALL_CHAR: wall_textures.textures}
+ else:
+ self._wall_textures = {_DEFAULT_WALL_CHAR: [self._mjcf_root.asset.add(
+ 'texture', type='2d', name='wall', builtin='flat',
+ rgb1=[.8, .8, .8], width=100, height=100)]}
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ self._floor_textures = [
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='aesthetic_texture_main',
+ file=ground_info.file,
+ type=ground_info.type),
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='aesthetic_texture',
+ file=ground_info.file,
+ type=ground_info.type)
+ ]
+ elif floor_textures:
+ self._floor_textures = floor_textures.textures
+ self.attach(floor_textures)
+ else:
+ self._floor_textures = [self._mjcf_root.asset.add(
+ 'texture', type='2d', name='floor', builtin='flat',
+ rgb1=[.2, .2, .2], width=100, height=100)]
+
+ ground_x = ((self._maze.width - 1) + 1) * (xy_scale / 2)
+ ground_y = ((self._maze.height - 1) + 1) * (xy_scale / 2)
+ self._mjcf_root.worldbody.add(
+ 'geom', name='ground', type='plane',
+ pos=[0, 0, 0], size=[ground_x, ground_y, 1], rgba=[0, 0, 0, 0])
+
+ self._maze_body = self._mjcf_root.worldbody.add('body', name='maze_body')
+
+ self._mjcf_root.visual.map.znear = 0.0005
+
+ # Choose the FOV so that the maze always fits nicely within the frame
+ # irrespective of actual maze size.
+ maze_size = max(self._maze.width, self._maze.height)
+ top_camera_fovy = (360 / np.pi) * np.arctan2(
+ _TOP_CAMERA_Y_PADDING_FACTOR * maze_size * self._xy_scale / 2,
+ _TOP_CAMERA_DISTANCE)
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera', name='top_camera',
+ pos=[0, 0, _TOP_CAMERA_DISTANCE], zaxis=[0, 0, 1], fovy=top_camera_fovy)
+
+ self._target_positions = ()
+ self._spawn_positions = ()
+
+ self._text_maze_regenerated_hook = None
+ self._tile_geom_names = {}
+
+ def _build_observables(self):
+ return MazeObservables(self)
+
+ @property
+ def top_camera(self):
+ return self._top_camera
+
+ @property
+ def xy_scale(self):
+ return self._xy_scale
+
+ @property
+ def z_height(self):
+ return self._z_height
+
+ @property
+ def maze(self):
+ return self._maze
+
+ @property
+ def text_maze_regenerated_hook(self):
+ """A callback that is executed after the LabMaze object is regenerated."""
+ return self._text_maze_regenerated_hook
+
+ @text_maze_regenerated_hook.setter
+ def text_maze_regenerated_hook(self, hook):
+ self._text_maze_regenerated_hook = hook
+
+ @property
+ def target_positions(self):
+ """A tuple of Cartesian target positions generated for the current maze."""
+ return self._target_positions
+
+ @property
+ def spawn_positions(self):
+ """A tuple of Cartesian positions at which the agent can be spawned."""
+ return self._spawn_positions
+
+ @property
+ def target_grid_positions(self):
+ """A tuple of grid coordinates of targets generated for the current maze."""
+ return self._target_grid_positions
+
+ @property
+ def spawn_grid_positions(self):
+ """A tuple of grid coordinates at which the agent can be spawned."""
+ return self._spawn_grid_positions
+
+ def regenerate(self, random_state=np.random.RandomState()):
+ """Generates a new maze layout."""
+ del random_state
+ self._maze.regenerate()
+ logging.debug('GENERATED MAZE:\n%s', self._maze.entity_layer)
+ self._find_spawn_and_target_positions()
+
+ if self._text_maze_regenerated_hook:
+ self._text_maze_regenerated_hook()
+
+ # Remove old texturing planes.
+ for geom_name in self._texturing_geom_names:
+ del self._mjcf_root.worldbody.geom[geom_name]
+ self._texturing_geom_names = []
+
+ # Remove old texturing materials.
+ for material_name in self._texturing_material_names:
+ del self._mjcf_root.asset.material[material_name]
+ self._texturing_material_names = []
+
+ # Remove old actual-wall geoms.
+ self._maze_body.geom.clear()
+
+ self._current_wall_texture = {
+ wall_char: np.random.choice(wall_textures)
+ for wall_char, wall_textures in self._wall_textures.items()
+ }
+
+ for wall_char in self._wall_textures:
+ self._make_wall_geoms(wall_char)
+ self._make_floor_variations()
+
+ def _make_wall_geoms(self, wall_char):
+ walls = covering.make_walls(
+ self._maze.entity_layer, wall_char=wall_char, make_odd_sized_walls=True)
+ for i, wall in enumerate(walls):
+ wall_mid = covering.GridCoordinates(
+ (wall.start.y + wall.end.y - 1) / 2,
+ (wall.start.x + wall.end.x - 1) / 2)
+ wall_pos = np.array([(wall_mid.x - self._x_offset) * self._xy_scale,
+ -(wall_mid.y - self._y_offset) * self._xy_scale,
+ self._z_height / 2])
+ wall_size = np.array([(wall.end.x - wall_mid.x - 0.5) * self._xy_scale,
+ (wall.end.y - wall_mid.y - 0.5) * self._xy_scale,
+ self._z_height / 2])
+ self._maze_body.add('geom', name='wall{}_{}'.format(wall_char, i),
+ type='box', pos=wall_pos, size=wall_size,
+ group=_WALL_GEOM_GROUP)
+ self._make_wall_texturing_planes(wall_char, i, wall_pos, wall_size)
+
+ def _make_wall_texturing_planes(self, wall_char, wall_id,
+ wall_pos, wall_size):
+ xyaxes = {
+ 'x': {-1: [0, -1, 0, 0, 0, 1], 1: [0, 1, 0, 0, 0, 1]},
+ 'y': {-1: [1, 0, 0, 0, 0, 1], 1: [-1, 0, 0, 0, 0, 1]},
+ 'z': {-1: [-1, 0, 0, 0, 1, 0], 1: [1, 0, 0, 0, 1, 0]}
+ }
+ for direction_index, direction in enumerate(('x', 'y', 'z')):
+ index = [i for i in range(3) if i != direction_index]
+ delta_vector = np.array([int(i == direction_index) for i in range(3)])
+ material_name = 'wall{}_{}_{}'.format(wall_char, wall_id, direction)
+ self._texturing_material_names.append(material_name)
+ mat = self._mjcf_root.asset.add(
+ 'material', name=material_name,
+ texture=self._current_wall_texture[wall_char],
+ texrepeat=(2 * wall_size[index] / self._xy_scale))
+ for sign, sign_name in zip((-1, 1), ('neg', 'pos')):
+ if direction == 'z' and sign == -1:
+ continue
+ geom_name = (
+ 'wall{}_{}_texturing_{}_{}'.format(
+ wall_char, wall_id, sign_name, direction))
+ self._texturing_geom_names.append(geom_name)
+ self._mjcf_root.worldbody.add(
+ 'geom', type='plane', name=geom_name,
+ pos=(wall_pos + sign * delta_vector * wall_size),
+ size=np.concatenate([wall_size[index], [self._xy_scale]]),
+ xyaxes=xyaxes[direction][sign], material=mat,
+ contype=0, conaffinity=0)
+
+ def _make_floor_variations(self, build_tile_geoms_fn=None):
+ """Builds the floor tiles.
+
+ Args:
+ build_tile_geoms_fn: An optional callable returning floor tile geoms.
+ If not passed, the floor will be built using a default covering method.
+ Takes a kwarg `wall_char` that can be used to control how active floor
+ tiles are selected.
+ """
+ main_floor_texture = np.random.choice(self._floor_textures)
+ for variation in _DEFAULT_FLOOR_CHAR + string.ascii_uppercase:
+ if variation not in self._maze.variations_layer:
+ break
+
+ if build_tile_geoms_fn is None:
+ # Break the floor variation down to odd-sized tiles.
+ tiles = covering.make_walls(self._maze.variations_layer,
+ wall_char=variation,
+ make_odd_sized_walls=True)
+ else:
+ tiles = build_tile_geoms_fn(wall_char=variation)
+
+ # Sample a texture that's not the same as the main floor texture.
+ variation_texture = main_floor_texture
+ if variation != _DEFAULT_FLOOR_CHAR:
+ if len(self._floor_textures) == 1:
+ return
+ else:
+ while variation_texture is main_floor_texture:
+ variation_texture = np.random.choice(self._floor_textures)
+
+ for i, tile in enumerate(tiles):
+ tile_mid = covering.GridCoordinates(
+ (tile.start.y + tile.end.y - 1) / 2,
+ (tile.start.x + tile.end.x - 1) / 2)
+ tile_pos = np.array([(tile_mid.x - self._x_offset) * self._xy_scale,
+ -(tile_mid.y - self._y_offset) * self._xy_scale,
+ 0.0])
+ tile_size = np.array([(tile.end.x - tile_mid.x - 0.5) * self._xy_scale,
+ (tile.end.y - tile_mid.y - 0.5) * self._xy_scale,
+ self._xy_scale])
+ if variation == _DEFAULT_FLOOR_CHAR:
+ tile_name = 'floor_{}'.format(i)
+ else:
+ tile_name = 'floor_{}_{}'.format(variation, i)
+ self._tile_geom_names[tile.start] = tile_name
+ self._texturing_material_names.append(tile_name)
+ self._texturing_geom_names.append(tile_name)
+ material = self._mjcf_root.asset.add(
+ 'material', name=tile_name, texture=variation_texture,
+ texrepeat=(2 * tile_size[[0, 1]] / self._xy_scale))
+ self._mjcf_root.worldbody.add(
+ 'geom', name=tile_name, type='plane', material=material,
+ pos=tile_pos, size=tile_size, contype=0, conaffinity=0)
+
+ @property
+ def ground_geoms(self):
+ return tuple([
+ geom for geom in self.mjcf_model.find_all('geom')
+ if 'ground' in geom.name
+ ])
+
+ def find_token_grid_positions(self, tokens):
+ out = {token: [] for token in tokens}
+ for y in range(self._maze.entity_layer.shape[0]):
+ for x in range(self._maze.entity_layer.shape[1]):
+ for token in tokens:
+ if self._maze.entity_layer[y, x] == token:
+ out[token].append((y, x))
+ return out
+
+ def grid_to_world_positions(self, grid_positions):
+ out = []
+ for y, x in grid_positions:
+ out.append(np.array([(x - self._x_offset) * self._xy_scale,
+ -(y - self._y_offset) * self._xy_scale,
+ 0.0]))
+ return out
+
+ def world_to_grid_positions(self, world_positions):
+ out = []
+ # The order of x and y is reversed between the grid-position and
+ # world-position formats.
+ for x, y, _ in world_positions:
+ out.append(np.array([self._y_offset - y / self._xy_scale,
+ self._x_offset + x / self._xy_scale]))
+ return out
+
+ def _find_spawn_and_target_positions(self):
+ grid_positions = self.find_token_grid_positions([
+ labmaze.defaults.OBJECT_TOKEN, labmaze.defaults.SPAWN_TOKEN])
+ self._target_grid_positions = tuple(
+ grid_positions[labmaze.defaults.OBJECT_TOKEN])
+ self._spawn_grid_positions = tuple(
+ grid_positions[labmaze.defaults.SPAWN_TOKEN])
+ self._target_positions = tuple(
+ self.grid_to_world_positions(self._target_grid_positions))
+ self._spawn_positions = tuple(
+ self.grid_to_world_positions(self._spawn_grid_positions))
+
+
+class MazeObservables(composer.Observables):
+
+ @composer.observable
+ def top_camera(self):
+ return observable.MJCFCamera(self._entity.top_camera)
+
+
+class RandomMazeWithTargets(MazeWithTargets):
+ """A randomly generated 2D maze with target positions."""
+
+ def _build(self,
+ x_cells,
+ y_cells,
+ xy_scale=2.0,
+ z_height=2.0,
+ max_rooms=labmaze.defaults.MAX_ROOMS,
+ room_min_size=labmaze.defaults.ROOM_MIN_SIZE,
+ room_max_size=labmaze.defaults.ROOM_MAX_SIZE,
+ spawns_per_room=labmaze.defaults.SPAWN_COUNT,
+ targets_per_room=labmaze.defaults.OBJECT_COUNT,
+ max_variations=labmaze.defaults.MAX_VARIATIONS,
+ simplify=labmaze.defaults.SIMPLIFY,
+ skybox_texture=None,
+ wall_textures=None,
+ floor_textures=None,
+ aesthetic='default',
+ name='random_maze'):
+ """Initializes this random maze arena.
+
+ Args:
+ x_cells: The number of cells along the x-direction of the maze. Must be
+ an odd integer.
+ y_cells: The number of cells along the y-direction of the maze. Must be
+ an odd integer.
+ xy_scale: The size of each maze cell in metres.
+ z_height: The z-height of the maze in metres.
+ max_rooms: (optional) The maximum number of rooms in each generated maze.
+ room_min_size: (optional) The minimum size of each room generated.
+ room_max_size: (optional) The maximum size of each room generated.
+ spawns_per_room: (optional) Number of spawn points
+ to generate in each room.
+ targets_per_room: (optional) Number of targets to generate in each room.
+ max_variations: (optional) Maximum number of variations to generate
+ in the variations layer.
+ simplify: (optional) Flag to simplify the maze.
+ skybox_texture: (optional) A `composer.Entity` that provides a texture
+ asset for the skybox.
+ wall_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze walls.
+ floor_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze floor.
+ aesthetic: (optional) Option to adjust the material properties and skybox.
+ name: (optional) A string, the name of this arena.
+ """
+ random_seed = np.random.randint(2147483648) # 2**31
+ super()._build(
+ maze=labmaze.RandomMaze(
+ height=y_cells,
+ width=x_cells,
+ max_rooms=max_rooms,
+ room_min_size=room_min_size,
+ room_max_size=room_max_size,
+ max_variations=max_variations,
+ spawns_per_room=spawns_per_room,
+ objects_per_room=targets_per_room,
+ simplify=simplify,
+ random_seed=random_seed),
+ xy_scale=xy_scale,
+ z_height=z_height,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ aesthetic=aesthetic,
+ name=name)
diff --git a/dm_control/locomotion/arenas/mazes_test.py b/dm_control/locomotion/arenas/mazes_test.py
new file mode 100644
index 00000000..82b080d9
--- /dev/null
+++ b/dm_control/locomotion/arenas/mazes_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.mazes."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+
+
+class MazesTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+
+ # Set the wall and floor textures to match DMLab and set the skybox.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures)
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/arenas/padded_room.py b/dm_control/locomotion/arenas/padded_room.py
new file mode 100644
index 00000000..3584c046
--- /dev/null
+++ b/dm_control/locomotion/arenas/padded_room.py
@@ -0,0 +1,81 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A LabMaze square room where the outermost cells are always empty."""
+
+import labmaze
+import numpy as np
+_PADDING = 4
+
+
+class PaddedRoom(labmaze.BaseMaze):
+ """A LabMaze square room where the outermost cells are always empty."""
+
+ def __init__(self,
+ room_size,
+ num_objects=0,
+ random_state=None,
+ pad_with_walls=True,
+ num_agent_spawn_positions=1):
+ self._room_size = room_size
+ self._num_objects = num_objects
+ self._num_agent_spawn_positions = num_agent_spawn_positions
+ self._random_state = random_state or np.random
+
+ empty_maze = '\n'.join(['.' * (room_size + _PADDING)] *
+ (room_size + _PADDING) + [''])
+
+ self._entity_layer = labmaze.TextGrid(empty_maze)
+
+ if pad_with_walls:
+ self._entity_layer[0, :] = '*'
+ self._entity_layer[-1, :] = '*'
+ self._entity_layer[:, 0] = '*'
+ self._entity_layer[:, -1] = '*'
+
+ self._variations_layer = labmaze.TextGrid(empty_maze)
+
+ def regenerate(self):
+ self._entity_layer[1:-1, 1:-1] = ' '
+ self._variations_layer[:, :] = '.'
+
+ generated = list(
+ self._random_state.choice(
+ self._room_size * self._room_size,
+ self._num_objects + self._num_agent_spawn_positions,
+ replace=False))
+ for i, obj in enumerate(generated):
+ if i < self._num_agent_spawn_positions:
+ token = labmaze.defaults.SPAWN_TOKEN
+ else:
+ token = labmaze.defaults.OBJECT_TOKEN
+ obj_y, obj_x = obj // self._room_size, obj % self._room_size
+ self._entity_layer[obj_y + _PADDING // 2,
+ obj_x + _PADDING // 2] = token
+
+ @property
+ def entity_layer(self):
+ return self._entity_layer
+
+ @property
+ def variations_layer(self):
+ return self._variations_layer
+
+ @property
+ def width(self):
+ return self._room_size + _PADDING
+
+ @property
+ def height(self):
+ return self._room_size + _PADDING
diff --git a/dm_control/locomotion/arenas/padded_room_test.py b/dm_control/locomotion/arenas/padded_room_test.py
new file mode 100644
index 00000000..25574f63
--- /dev/null
+++ b/dm_control/locomotion/arenas/padded_room_test.py
@@ -0,0 +1,42 @@
+# Copyright 2021 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.padded_room."""
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.arenas import padded_room
+
+
+class PaddedRoomTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+ # Set the wall and floor textures to match DMLab and set the skybox.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+
+ maze = padded_room.PaddedRoom(room_size=4, num_objects=2)
+ arena = mazes.MazeWithTargets(
+ maze=maze,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures)
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/examples/__init__.py b/dm_control/locomotion/examples/__init__.py
new file mode 100644
index 00000000..06e95be7
--- /dev/null
+++ b/dm_control/locomotion/examples/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Walkers for Locomotion tasks."""
+
+from dm_control.locomotion.walkers.cmu_humanoid import CMUHumanoid
+from dm_control.locomotion.walkers.cmu_humanoid import CMUHumanoidPositionControlled
+from dm_control.locomotion.walkers.cmu_humanoid import CMUHumanoidPositionControlledV2020
diff --git a/dm_control/locomotion/examples/basic_cmu_2019.py b/dm_control/locomotion/examples/basic_cmu_2019.py
new file mode 100644
index 00000000..02b3faba
--- /dev/null
+++ b/dm_control/locomotion/examples/basic_cmu_2019.py
@@ -0,0 +1,222 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Produces reference environments for CMU humanoid locomotion tasks."""
+
+import functools
+
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.tasks import go_to_target
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.walkers import cmu_humanoid
+from labmaze import fixed_maze
+
+
+def cmu_humanoid_run_walls(random_state=None):
+ """Requires a CMU humanoid to run down a corridor obstructed by walls."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena that is obstructed by walls.
+ arena = corr_arenas.WallsCorridor(
+ wall_gap=4.,
+ wall_width=distributions.Uniform(1, 7),
+ wall_height=3.0,
+ corridor_width=10,
+ corridor_length=100,
+ include_initial_padding=False)
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(0.5, 0, 0),
+ target_velocity=3.0,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_run_gaps(random_state=None):
+ """Requires a CMU humanoid to run down a corridor with gaps."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena with gaps, where the sizes of the gaps and
+ # platforms are uniformly randomized.
+ arena = corr_arenas.GapsCorridor(
+ platform_length=distributions.Uniform(.3, 2.5),
+ gap_length=distributions.Uniform(.5, 1.25),
+ corridor_width=10,
+ corridor_length=100)
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(0.5, 0, 0),
+ target_velocity=3.0,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_go_to_target(random_state=None):
+ """Requires a CMU humanoid to go to a target."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled()
+
+ # Build a standard floor arena.
+ arena = floors.Floor()
+
+ # Build a task that rewards the agent for going to a target.
+ task = go_to_target.GoToTarget(
+ walker=walker,
+ arena=arena,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_maze_forage(random_state=None):
+ """Requires a CMU humanoid to find all items in a maze."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ # Build a task that rewards the agent for obtaining targets.
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ target_reward_scale=50.,
+ physics_timestep=0.005,
+ control_timestep=0.03,
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_heterogeneous_forage(random_state=None):
+ """Requires a CMU humanoid to find all items of a particular type in a maze."""
+ level = ('*******\n'
+ '* *\n'
+ '* P *\n'
+ '* *\n'
+ '* G *\n'
+ '* *\n'
+ '*******\n')
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ maze = fixed_maze.FixedMazeWithRandomGoals(
+ entity_layer=level,
+ variations_layer=None,
+ num_spawns=1,
+ num_objects=6,
+ )
+ arena = mazes.MazeWithTargets(
+ maze=maze,
+ xy_scale=3.0,
+ z_height=2.0,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+ task = random_goal_maze.ManyHeterogeneousGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builders=[
+ functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0.4, 0),
+ rgb2=(0, 0.7, 0)),
+ functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0.4, 0, 0),
+ rgb2=(0.7, 0, 0)),
+ ],
+ randomize_spawn_rotation=False,
+ target_type_rewards=[30., -10.],
+ target_type_proportions=[1, 1],
+ shuffle_target_builders=True,
+ aliveness_reward=0.01,
+ control_timestep=.03,
+ )
+
+ return composer.Environment(
+ time_limit=25,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
diff --git a/dm_control/locomotion/examples/basic_rodent_2020.py b/dm_control/locomotion/examples/basic_rodent_2020.py
new file mode 100644
index 00000000..51bd8f33
--- /dev/null
+++ b/dm_control/locomotion/examples/basic_rodent_2020.py
@@ -0,0 +1,171 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Produces reference environments for rodent tasks."""
+
+import functools
+
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.locomotion.arenas import bowl
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.tasks import escape
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.tasks import reach
+from dm_control.locomotion.walkers import rodent
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+def rodent_escape_bowl(random_state=None):
+ """Requires a rodent to climb out of a bowl-shaped terrain."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a bowl-shaped arena.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for being far from the origin.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=20,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_run_gaps(random_state=None):
+ """Requires a rodent to run down a corridor with gaps."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena with gaps, where the sizes of the gaps and
+ # platforms are uniformly randomized.
+ arena = corr_arenas.GapsCorridor(
+ platform_length=distributions.Uniform(.4, .8),
+ gap_length=distributions.Uniform(.05, .2),
+ corridor_width=2,
+ corridor_length=40,
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(5, 0, 0),
+ walker_spawn_rotation=0,
+ target_velocity=1.0,
+ contact_termination=False,
+ terminate_at_height=-0.3,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_maze_forage(random_state=None):
+ """Requires a rodent to find all items in a maze."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a maze with rooms and targets.
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=.5,
+ z_height=.3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ wall_textures=wall_textures,
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for obtaining targets.
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.05,
+ height_above_ground=.125,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ target_reward_scale=50.,
+ contact_termination=False,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_two_touch(random_state=None):
+ """Requires a rodent to tap an orb, wait an interval, and tap it again."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build an open floor arena.
+ arena = floors.Floor(
+ size=(10., 10.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the walker for touching/reaching orbs with a
+ # specific time interval between touches.
+ task = reach.TwoTouch(
+ walker=walker,
+ arena=arena,
+ target_builders=[
+ functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025),
+ ],
+ randomize_spawn_rotation=True,
+ target_type_rewards=[25.],
+ shuffle_target_builders=False,
+ target_area=(1.5, 1.5),
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP,
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
diff --git a/dm_control/locomotion/examples/cmu_2020_tracking.py b/dm_control/locomotion/examples/cmu_2020_tracking.py
new file mode 100644
index 00000000..b279570a
--- /dev/null
+++ b/dm_control/locomotion/examples/cmu_2020_tracking.py
@@ -0,0 +1,53 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Produces reference environments for the CMU humanoid tracking task."""
+
+
+from dm_control import composer
+from dm_control.locomotion import arenas
+
+from dm_control.locomotion.mocap import cmu_mocap_data
+from dm_control.locomotion.tasks.reference_pose import tracking
+
+from dm_control.locomotion.walkers import cmu_humanoid
+
+
+def cmu_humanoid_tracking(random_state=None):
+ """Requires a CMU humanoid to track motion capture reference data."""
+
+ # Use a position-controlled CMU humanoid walker.
+ walker_type = cmu_humanoid.CMUHumanoidPositionControlledV2020
+
+ # Build an empty arena.
+ arena = arenas.Floor()
+
+ # Build a task that rewards the agent for tracking motion capture reference
+ # data.
+ task = tracking.MultiClipMocapTracking(
+ walker=walker_type,
+ arena=arena,
+ ref_path=cmu_mocap_data.get_path_for_cmu(version='2020'),
+ dataset='walk_tiny',
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=10,
+ reward_type='comic',
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
diff --git a/dm_control/locomotion/examples/examples_test.py b/dm_control/locomotion/examples/examples_test.py
new file mode 100644
index 00000000..a264f534
--- /dev/null
+++ b/dm_control/locomotion/examples/examples_test.py
@@ -0,0 +1,82 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for `dm_control.locomotion.examples`."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion.examples import basic_cmu_2019
+from dm_control.locomotion.examples import basic_rodent_2020
+import numpy as np
+
+
+_NUM_EPISODES = 5
+_NUM_STEPS_PER_EPISODE = 10
+
+
+class ExampleEnvironmentsTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def _validate_observation(self, observation, observation_spec):
+ self.assertEqual(list(observation.keys()), list(observation_spec.keys()))
+ for name, array_spec in observation_spec.items():
+ array_spec.validate(observation[name])
+
+ def _validate_reward_range(self, reward):
+ self.assertIsInstance(reward, float)
+ self.assertBetween(reward, 0, 1)
+
+ def _validate_discount(self, discount):
+ self.assertIsInstance(discount, float)
+ self.assertBetween(discount, 0, 1)
+
+ @parameterized.named_parameters(
+ ('cmu_humanoid_run_walls', basic_cmu_2019.cmu_humanoid_run_walls),
+ ('cmu_humanoid_run_gaps', basic_cmu_2019.cmu_humanoid_run_gaps),
+ ('cmu_humanoid_go_to_target', basic_cmu_2019.cmu_humanoid_go_to_target),
+ ('cmu_humanoid_maze_forage', basic_cmu_2019.cmu_humanoid_maze_forage),
+ ('cmu_humanoid_heterogeneous_forage',
+ basic_cmu_2019.cmu_humanoid_heterogeneous_forage),
+ ('rodent_escape_bowl', basic_rodent_2020.rodent_escape_bowl),
+ ('rodent_run_gaps', basic_rodent_2020.rodent_run_gaps),
+ ('rodent_maze_forage', basic_rodent_2020.rodent_maze_forage),
+ ('rodent_two_touch', basic_rodent_2020.rodent_two_touch),
+ )
+ def test_env_runs(self, env_constructor):
+ """Tests that the environment runs and is coherent with its specs."""
+ random_state = np.random.RandomState(99)
+
+ env = env_constructor(random_state=random_state)
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+ self.assertTrue(np.all(np.isfinite(action_spec.minimum)))
+ self.assertTrue(np.all(np.isfinite(action_spec.maximum)))
+
+ # Run a partial episode, check observations, rewards, discount.
+ for _ in range(_NUM_EPISODES):
+ time_step = env.reset()
+ for _ in range(_NUM_STEPS_PER_EPISODE):
+ self._validate_observation(time_step.observation, observation_spec)
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ self.assertIsNone(time_step.discount)
+ else:
+ self._validate_reward_range(time_step.reward)
+ self._validate_discount(time_step.discount)
+ action = random_state.uniform(action_spec.minimum, action_spec.maximum)
+ time_step = env.step(action)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/examples/explore.py b/dm_control/locomotion/examples/explore.py
new file mode 100644
index 00000000..971aa0dd
--- /dev/null
+++ b/dm_control/locomotion/examples/explore.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Simple script to launch viewer with an example environment."""
+
+from absl import app
+
+from dm_control.locomotion.examples import basic_cmu_2019
+from dm_control import viewer
+
+
+def main(unused_argv):
+ viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_gaps)
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/dm_control/locomotion/gaps.png b/dm_control/locomotion/gaps.png
new file mode 100644
index 00000000..9a6c16d0
Binary files /dev/null and b/dm_control/locomotion/gaps.png differ
diff --git a/dm_control/locomotion/mocap/__init__.py b/dm_control/locomotion/mocap/__init__.py
new file mode 100644
index 00000000..8a363bda
--- /dev/null
+++ b/dm_control/locomotion/mocap/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/locomotion/mocap/cmu_mocap_data.py b/dm_control/locomotion/mocap/cmu_mocap_data.py
new file mode 100644
index 00000000..37f1bb96
--- /dev/null
+++ b/dm_control/locomotion/mocap/cmu_mocap_data.py
@@ -0,0 +1,114 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""File loader for DeepMind-preprocessed version of the CMU motion capture data.
+
+The raw CMU point-cloud data is fitted onto the `dm_control.locomotion` CMU
+Humanoid walker model and re-exported as an HDF5 file
+(https://www.hdfgroup.org/solutions/hdf5/) that can be read by the
+`dm_control.locomotion.mocap` package.
+
+The original database is produced and hosted by Carnegie Mellon University at
+http://mocap.cs.cmu.edu/, and may be copied, modified, or redistributed without
+explicit permission (see http://mocap.cs.cmu.edu/faqs.php).
+"""
+
+import hashlib
+import os
+
+import requests
+import tqdm
+
+H5_FILENAME = {'2019': 'cmu_2019_08756c01.h5',
+ '2020': 'cmu_2020_dfe3e9e0.h5'}
+
+H5_PATHS = {k: (os.path.join(os.path.dirname(__file__), v),
+ os.path.join('~/.dm_control', v))
+ for k, v in H5_FILENAME.items()}
+H5_URL_BASE = 'https://storage.googleapis.com/dm_control/'
+H5_URL = {'2019': H5_URL_BASE+'cmu_2019_08756c01.h5',
+ '2020': H5_URL_BASE+'cmu_2020_dfe3e9e0.h5'}
+
+H5_BYTES = {'2019': 488143314,
+ '2020': 476559420}
+H5_SHA256 = {
+ '2019': '08756c01cb4ac20da9918e70e85c32d4880c6c8c16189b02a18b79a5e79afa2b',
+ '2020': 'dfe3e9e0b08d32960bdafbf89e541339ca8908a9a5e7f4a2c986362890d72863'}
+
+
+def _get_cached_file_path(version):
+ """Returns the path to the cached data file if one exists."""
+ for path in H5_PATHS[version]:
+ expanded_path = os.path.expanduser(path)
+ try:
+ if os.path.getsize(expanded_path) != H5_BYTES[version]:
+ continue
+ with open(expanded_path, 'rb'):
+ return expanded_path
+ except IOError:
+ continue
+ return None
+
+
+def _download_and_cache(version):
+ """Downloads CMU data into one of the candidate paths in H5_PATHS."""
+ for path in H5_PATHS[version]:
+ expanded_path = os.path.expanduser(path)
+ try:
+ os.makedirs(os.path.dirname(expanded_path), exist_ok=True)
+ f = open(expanded_path, 'wb+')
+ except IOError:
+ continue
+ with f:
+ try:
+ _download_into_file(f, version)
+ except BaseException:
+ os.unlink(expanded_path)
+ raise
+ return expanded_path
+ raise IOError('cannot open file to write download data into, '
+ f'paths attempted: {H5_PATHS[version]}')
+
+
+def _download_into_file(f, version, validate_hash=True):
+ """Download the CMU data into a file object that has been opened for write."""
+ with requests.get(H5_URL[version], stream=True) as req:
+ req.raise_for_status()
+ total_bytes = int(req.headers['Content-Length'])
+ progress_bar = tqdm.tqdm(
+ desc='Downloading CMU mocap data', total=total_bytes,
+ unit_scale=True, unit_divisor=1024)
+ try:
+ for chunk in req.iter_content(chunk_size=102400):
+ if chunk:
+ f.write(chunk)
+ progress_bar.update(len(chunk))
+ finally:
+ progress_bar.close()
+
+ if validate_hash:
+ f.seek(0)
+ if hashlib.sha256(f.read()).hexdigest() != H5_SHA256[version]:
+ raise RuntimeError('downloaded file is corrupted')
+
+
+def get_path_for_cmu(version='2020'):
+ """Path to mocap data fitted to a version of the CMU Humanoid model."""
+ assert version in H5_FILENAME.keys()
+ path = _get_cached_file_path(version)
+ if path is None:
+ path = _download_and_cache(version)
+ return path
+
diff --git a/dm_control/locomotion/mocap/loader.py b/dm_control/locomotion/mocap/loader.py
new file mode 100644
index 00000000..7587c8d8
--- /dev/null
+++ b/dm_control/locomotion/mocap/loader.py
@@ -0,0 +1,246 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Helpers for loading a collection of trajectories."""
+
+import abc
+import collections
+import operator
+
+from dm_control.composer import variation
+from dm_control.locomotion.mocap import mocap_pb2
+from dm_control.locomotion.mocap import trajectory
+from dm_control.utils import transformations as tr
+from google.protobuf import descriptor
+import numpy as np
+
+
+class TrajectoryLoader(metaclass=abc.ABCMeta):
+ """Base class for helpers that load and decode mocap trajectories."""
+
+ def __init__(self, trajectory_class=trajectory.Trajectory,
+ proto_modifier=()):
+ """Initializes this loader.
+
+ Args:
+ trajectory_class: A Python class that wraps a loaded trajectory proto.
+ proto_modifier: (optional) A callable, or an iterable of callables, each
+ of which modifies a trajectory proto in-place after it has been
+ deserialized.
+
+ Raises:
+ ValueError: If `proto_modifier` is specified, but contains a
+ non-callable entry.
+ """
+ self._trajectory_class = trajectory_class
+ if not isinstance(proto_modifier, collections.abc.Iterable):
+ if proto_modifier is None: # backwards compatibility
+ proto_modifier = ()
+ else:
+ proto_modifier = (proto_modifier,)
+ for modifier in proto_modifier:
+ if not callable(modifier):
+ raise ValueError('{} is not callable'.format(modifier))
+ self._proto_modifiers = proto_modifier
+
+ @abc.abstractmethod
+ def keys(self):
+ """The sequence of identifiers for the loaded trajectories."""
+
+ @abc.abstractmethod
+ def _get_proto_for_key(self, key):
+ """Returns a protocol buffer message corresponding to the requested key."""
+
+ def get_trajectory(self, key, start_time=None, end_time=None, start_step=None,
+ end_step=None, zero_out_velocities=True):
+ """Retrieves the trajectory identified by `key`."""
+ proto = self._get_proto_for_key(key)
+ for modifier in self._proto_modifiers:
+ modifier(proto)
+ return self._trajectory_class(proto, start_time=start_time,
+ end_time=end_time, start_step=start_step,
+ end_step=end_step,
+ zero_out_velocities=zero_out_velocities)
+
+
+class HDF5TrajectoryLoader(TrajectoryLoader):
+ """A helper for loading and decoding mocap trajectories from HDF5.
+
+ In order to use this class, h5py must be installed (it's an optional
+ dependency of dm_control).
+ """
+
+ def __init__(self, path, trajectory_class=trajectory.Trajectory,
+ proto_modifier=()):
+ # h5py is an optional dependency of dm_control, so only try to import
+ # if it's used.
+ try:
+ import h5py # pylint: disable=g-import-not-at-top
+ except ImportError as e:
+ raise ImportError(
+ 'h5py not found. When installing dm_control, '
+ 'use `pip install dm_control[HDF5]` to enable HDF5TrajectoryLoader.'
+ ) from e
+ self._h5_file = h5py.File(path, mode='r')
+ self._keys = tuple(sorted(self._h5_file.keys()))
+ super().__init__(
+ trajectory_class=trajectory_class, proto_modifier=proto_modifier)
+
+ def keys(self):
+ return self._keys
+
+ def _fill_primitive_proto_fields(self, proto, h5_group, skip_fields=()):
+ for field in proto.DESCRIPTOR.fields:
+ if field.name in skip_fields or field.name not in h5_group.attrs:
+ continue
+ elif field.type not in (descriptor.FieldDescriptor.TYPE_GROUP,
+ descriptor.FieldDescriptor.TYPE_MESSAGE):
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ getattr(proto, field.name).extend(h5_group.attrs[field.name])
+ else:
+ setattr(proto, field.name, h5_group.attrs[field.name])
+
+ def _fill_repeated_proto_message_fields(self, proto_container,
+ h5_container, h5_prefix):
+ for item_id in range(len(h5_container)):
+ h5_item = h5_container['{:s}_{:d}'.format(h5_prefix, item_id)]
+ proto = proto_container.add()
+ self._fill_primitive_proto_fields(proto, h5_item)
+
+ def _get_proto_for_key(self, key):
+ """Returns a trajectory protocol buffer message for the specified key."""
+ if isinstance(key, str):
+ key = key.encode('utf-8')
+
+ h5_trajectory = self._h5_file[key]
+ num_steps = h5_trajectory.attrs['num_steps']
+
+ proto = mocap_pb2.FittedTrajectory()
+ proto.identifier = key
+ self._fill_primitive_proto_fields(proto, h5_trajectory,
+ skip_fields=('identifier',))
+
+ for _ in range(num_steps):
+ proto.timesteps.add()
+
+ h5_walkers = h5_trajectory['walkers']
+ for walker_id in range(len(h5_walkers)):
+ h5_walker = h5_walkers['walker_{:d}'.format(walker_id)]
+ walker_proto = proto.walkers.add()
+ self._fill_primitive_proto_fields(walker_proto, h5_walker)
+ self._fill_repeated_proto_message_fields(
+ walker_proto.scaling.subtree,
+ h5_walker['scaling'], h5_prefix='subtree')
+ self._fill_repeated_proto_message_fields(
+ walker_proto.markers.marker,
+ h5_walker['markers'], h5_prefix='marker')
+
+ walker_fields = dict()
+ for field in mocap_pb2.WalkerPose.DESCRIPTOR.fields:
+ walker_fields[field.name] = np.asarray(h5_walker[field.name])
+
+ for timestep_id, timestep in enumerate(proto.timesteps):
+ walker_timestep = timestep.walkers.add()
+ for k, v in walker_fields.items():
+ getattr(walker_timestep, k).extend(v[:, timestep_id])
+
+ h5_props = h5_trajectory['props']
+ for prop_id in range(len(h5_props)):
+ h5_prop = h5_props['prop_{:d}'.format(prop_id)]
+ prop_proto = proto.props.add()
+ self._fill_primitive_proto_fields(prop_proto, h5_prop)
+
+ prop_fields = dict()
+ for field in mocap_pb2.PropPose.DESCRIPTOR.fields:
+ prop_fields[field.name] = np.asarray(h5_prop[field.name])
+
+ for timestep_id, timestep in enumerate(proto.timesteps):
+ prop_timestep = timestep.props.add()
+ for k, v in prop_fields.items():
+ getattr(prop_timestep, k).extend(v[:, timestep_id])
+
+ return proto
+
+
+class PropMassLimiter:
+ """A trajectory proto modifier that enforces a maximum mass for each prop."""
+
+ def __init__(self, max_mass):
+ self._max_mass = max_mass
+
+ def __call__(self, proto, random_state=None):
+ for prop in proto.props:
+ prop.mass = min(prop.mass, self._max_mass)
+
+
+class PropResizer:
+ """A trajectory proto modifier that changes prop sizes and mass."""
+
+ def __init__(self, size_factor=None, size_delta=None, mass=None):
+ if size_factor and size_delta:
+ raise ValueError(
+ 'Only one of `size_factor` or `size_delta` can be specified.')
+ elif size_factor:
+ self._size_variation = size_factor
+ self._size_op = operator.mul
+ else:
+ self._size_variation = size_delta
+ self._size_op = operator.add
+ self._mass = mass
+
+ def __call__(self, proto, random_state=None):
+ for prop in proto.props:
+ size_value = variation.evaluate(self._size_variation,
+ random_state=random_state)
+ if not np.shape(size_value):
+ size_value = np.full(len(prop.size), size_value)
+ for i in range(len(prop.size)):
+ prop.size[i] = self._size_op(prop.size[i], size_value[i])
+ prop.mass = variation.evaluate(self._mass, random_state=random_state)
+
+
+class ZOffsetter:
+ """A trajectory proto modifier that shifts the z position of a trajectory."""
+
+ def __init__(self, z_offset=0.0):
+ self._z_offset = z_offset
+
+ def _add_z_offset(self, proto_field):
+ if len(proto_field) % 3:
+ raise ValueError('Length of proto_field is not a multiple of 3.')
+ for i in range(2, len(proto_field), 3):
+ proto_field[i] += self._z_offset
+
+ def __call__(self, proto, random_state=None):
+ for t in proto.timesteps:
+ for walker_pose in t.walkers:
+ # Shift walker position.
+ self._add_z_offset(walker_pose.position)
+ self._add_z_offset(walker_pose.body_positions)
+ self._add_z_offset(walker_pose.center_of_mass)
+ for prop_pose in t.props:
+ # Shift prop position.
+ self._add_z_offset(prop_pose.position)
+
+
+class AppendageFixer:
+ """A trajectory proto modifier that makes appendage positions egocentric."""
+ def __call__(self, proto, random_state=None):
+ for t in proto.timesteps:
+ for walker_pose in t.walkers:
+ xpos = np.asarray(walker_pose.position)
+ xquat = np.asarray(walker_pose.quaternion)
+ appendages = np.reshape(walker_pose.appendages, (-1, 3))
+ xmat = tr.quat_to_mat(xquat)[:3, :3]
+ walker_pose.appendages[:] = np.ravel((appendages - xpos) @ xmat)
diff --git a/dm_control/locomotion/mocap/loader_test.py b/dm_control/locomotion/mocap/loader_test.py
new file mode 100644
index 00000000..9f18ed15
--- /dev/null
+++ b/dm_control/locomotion/mocap/loader_test.py
@@ -0,0 +1,77 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for loader."""
+
+import os
+
+from absl.testing import absltest
+from dm_control.locomotion.mocap import loader
+from dm_control.locomotion.mocap import mocap_pb2
+from dm_control.locomotion.mocap import trajectory
+from google.protobuf import descriptor
+from google.protobuf import text_format
+
+from dm_control.utils import io as resources
+
+TEXTPROTOS = [
+ os.path.join(os.path.dirname(__file__), 'test_001.textproto'),
+ os.path.join(os.path.dirname(__file__), 'test_002.textproto'),
+]
+
+HDF5 = os.path.join(os.path.dirname(__file__), 'test_trajectories.h5')
+
+
+class HDF5TrajectoryLoaderTest(absltest.TestCase):
+
+ def assert_proto_equal(self, x, y, msg=''):
+ self.assertEqual(type(x), type(y), msg=msg)
+ for field in x.DESCRIPTOR.fields:
+ x_field = getattr(x, field.name)
+ y_field = getattr(y, field.name)
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ for i, (x_child, y_child) in enumerate(zip(x_field, y_field)):
+ self.assert_proto_equal(
+ x_child, y_child,
+ msg=os.path.join(msg, '{}[{}]'.format(field.name, i)))
+ else:
+ self.assertEqual(list(x_field), list(y_field),
+ msg=os.path.join(msg, field.name))
+ else:
+ if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ self.assert_proto_equal(
+ x_field, y_field, msg=os.path.join(msg, field.name))
+ else:
+ self.assertEqual(x_field, y_field, msg=os.path.join(msg, field.name))
+
+ def test_hdf5_agrees_with_textprotos(self):
+
+ hdf5_loader = loader.HDF5TrajectoryLoader(
+ resources.GetResourceFilename(HDF5))
+
+ for textproto_path in TEXTPROTOS:
+ trajectory_textproto = resources.GetResource(textproto_path)
+ trajectory_from_textproto = mocap_pb2.FittedTrajectory()
+ text_format.Parse(trajectory_textproto, trajectory_from_textproto)
+
+ trajectory_identifier = (
+ trajectory_from_textproto.identifier.encode('utf-8'))
+ self.assert_proto_equal(
+ hdf5_loader.get_trajectory(trajectory_identifier)._proto,
+ trajectory.Trajectory(trajectory_from_textproto)._proto)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/mocap/mocap.proto b/dm_control/locomotion/mocap/mocap.proto
new file mode 100644
index 00000000..9f370e7b
--- /dev/null
+++ b/dm_control/locomotion/mocap/mocap.proto
@@ -0,0 +1,192 @@
+syntax = "proto3";
+
+package dm_control.locomotion.mocap;
+
+// A motion capture tracking marker.
+message Marker {
+ // The name that identifies this marker.
+ string name = 1;
+
+ // The name of the parent frame to which this marker is attached.
+ string parent = 2;
+
+ // The position of this marker within the parent frame.
+ repeated double position = 3 [packed = true];
+
+ // The orientation of this marker within the parent frame.
+ repeated double quaternion = 4 [packed = true];
+
+ // NEXT_ID: 5
+}
+
+// A collection of Markers.
+message Markers {
+ repeated Marker marker = 1;
+}
+
+// Scaling information for a single subtree within a walker's model.
+message SubtreeScaling {
+ // The name of the body at which this scaling is specified.
+ string body_name = 1;
+
+ // The desired length of the parent at which the above body is attached.
+ double parent_length = 2;
+
+ // The factor by which to scale the size of each geom in this subtree.
+ double size_factor = 3;
+
+ // NEXT_ID: 5
+}
+
+// Scaling information for a walker's model.
+message WalkerScaling {
+ repeated SubtreeScaling subtree = 1;
+}
+
+message Walker {
+ // A name that identifies this walker.
+ string name = 1;
+
+ enum Model {
+ UNSPECIFIED = 0;
+ CMU_2019 = 1;
+ RESERVED_MODEL_ID_2 = 2;
+ RESERVED_MODEL_ID_3 = 3;
+ CMU_2020 = 4;
+ RESERVED_MODEL_ID_5 = 5;
+ RESERVED_MODEL_ID_6 = 6;
+ // NEXT_ID: 7
+ }
+ Model model = 2;
+
+ // Factors used to scale the base model to fit the mocap actor.
+ // Scaling must be applied in the same order as listed in the proto.
+ WalkerScaling scaling = 3;
+
+ // Mocap markers placed on this walker.
+ Markers markers = 4;
+
+ // Total mass of the walker, in kilograms.
+ double mass = 5;
+
+ // Names of end effectors as present in WalkerPose.
+ repeated string end_effector_names = 6;
+
+ // Names of appendages as present in WalkerPose.
+ repeated string appendage_names = 7;
+
+ // NEXT_ID: 8
+}
+
+message Prop {
+ // A name that identifies this prop.
+ string name = 1;
+
+ enum Shape {
+ UNSPECIFIED = 0;
+ SPHERE = 1;
+ BOX = 2;
+ }
+ Shape shape = 2;
+
+ // Size of this prop, in meters.
+ // These are the half-lengths of each dimension of the prop. The number of
+ // dimensions depends on the prop type, e.g. a sphere only requires one number
+ // to specify its radius, whereas a box requires three numbers, one for each
+ // axis.
+ repeated double size = 3 [packed = true];
+
+ // Mass of this prop, in kilograms.
+ double mass = 4;
+
+ // NEXT_ID: 5
+}
+
+message WalkerPose {
+ // Cartesian position of the walker's root frame origin. Must be of length 3.
+ repeated double position = 1 [packed = true];
+
+ // Quaternion orientation of the walker's root frame.
+ // Must be of length 4 and normalized.
+ repeated double quaternion = 2 [packed = true];
+
+ // Joint positions of the walker.
+ repeated double joints = 3 [packed = true];
+
+ // Cartesian position of the walker's center of mass.
+ repeated double center_of_mass = 4 [packed = true];
+
+ // Cartesian position of the walker's end effectors in egocentric coordinates.
+ // Length must be a multiple of 3.
+ repeated double end_effectors = 5 [packed = true];
+
+ // Linear velocity of the walker's root frame origin.
+ // May be approximated by finite differences.
+ repeated double velocity = 6 [packed = true];
+
+ // Angular velocity of the walker's root frame origin.
+ // May be approximated by finite differences.
+ repeated double angular_velocity = 7 [packed = true];
+
+ // Velocity of the walker's joints.
+ // May be approximated by finite differences.
+ repeated double joints_velocity = 8 [packed = true];
+
+ // Cartesian position of the walker's appendages in egocentric coordinates.
+ // Length must be a multiple of 3.
+ repeated double appendages = 9 [packed = true];
+
+ // Cartesian position in global coordinates of the walker's body parts.
+ // Length must be a multiple of 3.
+ repeated double body_positions = 10 [packed = true];
+
+ // Orientation of the walker's body parts.
+ repeated double body_quaternions = 11 [packed = true];
+
+ // NEXT_ID: 12
+}
+
+message PropPose {
+ // Cartesian position of the prop. Must be of length 3.
+ repeated double position = 1 [packed = true];
+
+ // Quaternion orientation of the prop. Must be of length 4 and normalized.
+ repeated double quaternion = 2 [packed = true];
+
+ // Linear velocity of the prop. May be approximated by finite differences.
+ repeated double velocity = 3 [packed = true];
+
+ // Angular velocity of the prop. May be approximated by finite differences.
+ repeated double angular_velocity = 4 [packed = true];
+
+ // NEXT_ID: 5
+}
+
+message TimestepData {
+ repeated WalkerPose walkers = 1;
+ repeated PropPose props = 2;
+ // NEXT_ID: 3
+}
+
+// A motion-captured sequence that has been fitted to some rigid-body models.
+message FittedTrajectory {
+ // A string that uniquely identifies this motion capture trajectory.
+ string identifier = 1;
+
+ // The date on which this trajectory was captured.
+ int32 year = 2;
+ int32 month = 3;
+ int32 day = 4;
+
+ // The interval (in seconds) between successive timesteps in this trajectory.
+ double dt = 5;
+
+ // Descriptions of each walker model in this trajectory.
+ repeated Walker walkers = 6;
+
+ // Descriptions of each prop in this trajectory.
+ repeated Prop props = 7;
+
+ repeated TimestepData timesteps = 8;
+ // NEXT_ID: 9
+}
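The comments in `mocap.proto` state several constraints that the schema itself cannot enforce: positions have length 3, quaternions have length 4 and are normalized, end-effector arrays have a length that is a multiple of 3, and velocities may be approximated by finite differences over `dt`. The following is a minimal sketch of those checks on plain Python sequences. The helper names `validate_pose` and `finite_difference_velocity`, and the tolerance, are hypothetical and not part of dm_control.

```python
import math


def validate_pose(position, quaternion, end_effectors=(), tol=1e-6):
  """Checks the constraints documented on WalkerPose / PropPose.

  Hypothetical helper: raises ValueError on the first violated constraint.
  """
  if len(position) != 3:
    raise ValueError('position must be of length 3')
  if len(quaternion) != 4:
    raise ValueError('quaternion must be of length 4')
  norm = math.sqrt(sum(q * q for q in quaternion))
  if abs(norm - 1.0) > tol:
    raise ValueError('quaternion must be normalized, got norm %r' % norm)
  if len(end_effectors) % 3 != 0:
    raise ValueError('end_effectors length must be a multiple of 3')
  return True


def finite_difference_velocity(previous_position, current_position, dt):
  """Forward-difference estimate of linear velocity between two timesteps."""
  if dt <= 0:
    raise ValueError('dt must be positive')
  return [(curr - prev) / dt
          for prev, curr in zip(previous_position, current_position)]


# A prop pose with a unit quaternion passes validation.
validate_pose([1.0, 0.0, 3.0], [0.5, 0.5, 0.5, 0.5])

# Velocity estimate for a prop that drops 0.5 m in one 0.05 s timestep.
velocity = finite_difference_velocity([1.0, 0.0, 3.0], [1.0, 0.0, 2.5], 0.05)
```

A forward difference is only one possible choice. The proto does not say which scheme produced the stored velocities, so values computed this way need not match the recorded ones.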
diff --git a/dm_control/locomotion/mocap/mocap_pb2.py b/dm_control/locomotion/mocap/mocap_pb2.py
new file mode 100644
index 00000000..64370a88
--- /dev/null
+++ b/dm_control/locomotion/mocap/mocap_pb2.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: mocap.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmocap.proto\x12\x1b\x64m_control.locomotion.mocap\"T\n\x06Marker\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06parent\x18\x02 \x01(\t\x12\x14\n\x08position\x18\x03 \x03(\x01\x42\x02\x10\x01\x12\x16\n\nquaternion\x18\x04 \x03(\x01\x42\x02\x10\x01\">\n\x07Markers\x12\x33\n\x06marker\x18\x01 \x03(\x0b\x32#.dm_control.locomotion.mocap.Marker\"O\n\x0eSubtreeScaling\x12\x11\n\tbody_name\x18\x01 \x01(\t\x12\x15\n\rparent_length\x18\x02 \x01(\x01\x12\x13\n\x0bsize_factor\x18\x03 \x01(\x01\"M\n\rWalkerScaling\x12<\n\x07subtree\x18\x01 \x03(\x0b\x32+.dm_control.locomotion.mocap.SubtreeScaling\"\xa2\x03\n\x06Walker\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x38\n\x05model\x18\x02 \x01(\x0e\x32).dm_control.locomotion.mocap.Walker.Model\x12;\n\x07scaling\x18\x03 \x01(\x0b\x32*.dm_control.locomotion.mocap.WalkerScaling\x12\x35\n\x07markers\x18\x04 \x01(\x0b\x32$.dm_control.locomotion.mocap.Markers\x12\x0c\n\x04mass\x18\x05 \x01(\x01\x12\x1a\n\x12\x65nd_effector_names\x18\x06 \x03(\t\x12\x17\n\x0f\x61ppendage_names\x18\x07 \x03(\t\"\x98\x01\n\x05Model\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x0c\n\x08\x43MU_2019\x10\x01\x12\x17\n\x13RESERVED_MODEL_ID_2\x10\x02\x12\x17\n\x13RESERVED_MODEL_ID_3\x10\x03\x12\x0c\n\x08\x43MU_2020\x10\x04\x12\x17\n\x13RESERVED_MODEL_ID_5\x10\x05\x12\x17\n\x13RESERVED_MODEL_ID_6\x10\x06\"\x9b\x01\n\x04Prop\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x05shape\x18\x02 \x01(\x0e\x32\'.dm_control.locomotion.mocap.Prop.Shape\x12\x10\n\x04size\x18\x03 \x03(\x01\x42\x02\x10\x01\x12\x0c\n\x04mass\x18\x04 \x01(\x01\"-\n\x05Shape\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06SPHERE\x10\x01\x12\x07\n\x03\x42OX\x10\x02\"\xa8\x02\n\nWalkerPose\x12\x14\n\x08position\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x16\n\nquaternion\x18\x02 \x03(\x01\x42\x02\x10\x01\x12\x12\n\x06joints\x18\x03 \x03(\x01\x42\x02\x10\x01\x12\x1a\n\x0e\x63\x65nter_of_mass\x18\x04 \x03(\x01\x42\x02\x10\x01\x12\x19\n\rend_effectors\x18\x05 
\x03(\x01\x42\x02\x10\x01\x12\x14\n\x08velocity\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x1c\n\x10\x61ngular_velocity\x18\x07 \x03(\x01\x42\x02\x10\x01\x12\x1b\n\x0fjoints_velocity\x18\x08 \x03(\x01\x42\x02\x10\x01\x12\x16\n\nappendages\x18\t \x03(\x01\x42\x02\x10\x01\x12\x1a\n\x0e\x62ody_positions\x18\n \x03(\x01\x42\x02\x10\x01\x12\x1c\n\x10\x62ody_quaternions\x18\x0b \x03(\x01\x42\x02\x10\x01\"l\n\x08PropPose\x12\x14\n\x08position\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x16\n\nquaternion\x18\x02 \x03(\x01\x42\x02\x10\x01\x12\x14\n\x08velocity\x18\x03 \x03(\x01\x42\x02\x10\x01\x12\x1c\n\x10\x61ngular_velocity\x18\x04 \x03(\x01\x42\x02\x10\x01\"~\n\x0cTimestepData\x12\x38\n\x07walkers\x18\x01 \x03(\x0b\x32\'.dm_control.locomotion.mocap.WalkerPose\x12\x34\n\x05props\x18\x02 \x03(\x0b\x32%.dm_control.locomotion.mocap.PropPose\"\x82\x02\n\x10\x46ittedTrajectory\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\x0c\n\x04year\x18\x02 \x01(\x05\x12\r\n\x05month\x18\x03 \x01(\x05\x12\x0b\n\x03\x64\x61y\x18\x04 \x01(\x05\x12\n\n\x02\x64t\x18\x05 \x01(\x01\x12\x34\n\x07walkers\x18\x06 \x03(\x0b\x32#.dm_control.locomotion.mocap.Walker\x12\x30\n\x05props\x18\x07 \x03(\x0b\x32!.dm_control.locomotion.mocap.Prop\x12<\n\ttimesteps\x18\x08 \x03(\x0b\x32).dm_control.locomotion.mocap.TimestepDatab\x06proto3')
+
+
+
+_MARKER = DESCRIPTOR.message_types_by_name['Marker']
+_MARKERS = DESCRIPTOR.message_types_by_name['Markers']
+_SUBTREESCALING = DESCRIPTOR.message_types_by_name['SubtreeScaling']
+_WALKERSCALING = DESCRIPTOR.message_types_by_name['WalkerScaling']
+_WALKER = DESCRIPTOR.message_types_by_name['Walker']
+_PROP = DESCRIPTOR.message_types_by_name['Prop']
+_WALKERPOSE = DESCRIPTOR.message_types_by_name['WalkerPose']
+_PROPPOSE = DESCRIPTOR.message_types_by_name['PropPose']
+_TIMESTEPDATA = DESCRIPTOR.message_types_by_name['TimestepData']
+_FITTEDTRAJECTORY = DESCRIPTOR.message_types_by_name['FittedTrajectory']
+_WALKER_MODEL = _WALKER.enum_types_by_name['Model']
+_PROP_SHAPE = _PROP.enum_types_by_name['Shape']
+Marker = _reflection.GeneratedProtocolMessageType('Marker', (_message.Message,), {
+ 'DESCRIPTOR' : _MARKER,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.Marker)
+ })
+_sym_db.RegisterMessage(Marker)
+
+Markers = _reflection.GeneratedProtocolMessageType('Markers', (_message.Message,), {
+ 'DESCRIPTOR' : _MARKERS,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.Markers)
+ })
+_sym_db.RegisterMessage(Markers)
+
+SubtreeScaling = _reflection.GeneratedProtocolMessageType('SubtreeScaling', (_message.Message,), {
+ 'DESCRIPTOR' : _SUBTREESCALING,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.SubtreeScaling)
+ })
+_sym_db.RegisterMessage(SubtreeScaling)
+
+WalkerScaling = _reflection.GeneratedProtocolMessageType('WalkerScaling', (_message.Message,), {
+ 'DESCRIPTOR' : _WALKERSCALING,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.WalkerScaling)
+ })
+_sym_db.RegisterMessage(WalkerScaling)
+
+Walker = _reflection.GeneratedProtocolMessageType('Walker', (_message.Message,), {
+ 'DESCRIPTOR' : _WALKER,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.Walker)
+ })
+_sym_db.RegisterMessage(Walker)
+
+Prop = _reflection.GeneratedProtocolMessageType('Prop', (_message.Message,), {
+ 'DESCRIPTOR' : _PROP,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.Prop)
+ })
+_sym_db.RegisterMessage(Prop)
+
+WalkerPose = _reflection.GeneratedProtocolMessageType('WalkerPose', (_message.Message,), {
+ 'DESCRIPTOR' : _WALKERPOSE,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.WalkerPose)
+ })
+_sym_db.RegisterMessage(WalkerPose)
+
+PropPose = _reflection.GeneratedProtocolMessageType('PropPose', (_message.Message,), {
+ 'DESCRIPTOR' : _PROPPOSE,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.PropPose)
+ })
+_sym_db.RegisterMessage(PropPose)
+
+TimestepData = _reflection.GeneratedProtocolMessageType('TimestepData', (_message.Message,), {
+ 'DESCRIPTOR' : _TIMESTEPDATA,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.TimestepData)
+ })
+_sym_db.RegisterMessage(TimestepData)
+
+FittedTrajectory = _reflection.GeneratedProtocolMessageType('FittedTrajectory', (_message.Message,), {
+ 'DESCRIPTOR' : _FITTEDTRAJECTORY,
+ '__module__' : 'mocap_pb2'
+ # @@protoc_insertion_point(class_scope:dm_control.locomotion.mocap.FittedTrajectory)
+ })
+_sym_db.RegisterMessage(FittedTrajectory)
+
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  _MARKER.fields_by_name['position']._options = None
+  _MARKER.fields_by_name['position']._serialized_options = b'\020\001'
+  _MARKER.fields_by_name['quaternion']._options = None
+  _MARKER.fields_by_name['quaternion']._serialized_options = b'\020\001'
+  _PROP.fields_by_name['size']._options = None
+  _PROP.fields_by_name['size']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['position']._options = None
+  _WALKERPOSE.fields_by_name['position']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['quaternion']._options = None
+  _WALKERPOSE.fields_by_name['quaternion']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['joints']._options = None
+  _WALKERPOSE.fields_by_name['joints']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['center_of_mass']._options = None
+  _WALKERPOSE.fields_by_name['center_of_mass']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['end_effectors']._options = None
+  _WALKERPOSE.fields_by_name['end_effectors']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['velocity']._options = None
+  _WALKERPOSE.fields_by_name['velocity']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['angular_velocity']._options = None
+  _WALKERPOSE.fields_by_name['angular_velocity']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['joints_velocity']._options = None
+  _WALKERPOSE.fields_by_name['joints_velocity']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['appendages']._options = None
+  _WALKERPOSE.fields_by_name['appendages']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['body_positions']._options = None
+  _WALKERPOSE.fields_by_name['body_positions']._serialized_options = b'\020\001'
+  _WALKERPOSE.fields_by_name['body_quaternions']._options = None
+  _WALKERPOSE.fields_by_name['body_quaternions']._serialized_options = b'\020\001'
+  _PROPPOSE.fields_by_name['position']._options = None
+  _PROPPOSE.fields_by_name['position']._serialized_options = b'\020\001'
+  _PROPPOSE.fields_by_name['quaternion']._options = None
+  _PROPPOSE.fields_by_name['quaternion']._serialized_options = b'\020\001'
+  _PROPPOSE.fields_by_name['velocity']._options = None
+  _PROPPOSE.fields_by_name['velocity']._serialized_options = b'\020\001'
+  _PROPPOSE.fields_by_name['angular_velocity']._options = None
+  _PROPPOSE.fields_by_name['angular_velocity']._serialized_options = b'\020\001'
+  _MARKER._serialized_start=44
+  _MARKER._serialized_end=128
+  _MARKERS._serialized_start=130
+  _MARKERS._serialized_end=192
+  _SUBTREESCALING._serialized_start=194
+  _SUBTREESCALING._serialized_end=273
+  _WALKERSCALING._serialized_start=275
+  _WALKERSCALING._serialized_end=352
+  _WALKER._serialized_start=355
+  _WALKER._serialized_end=773
+  _WALKER_MODEL._serialized_start=621
+  _WALKER_MODEL._serialized_end=773
+  _PROP._serialized_start=776
+  _PROP._serialized_end=931
+  _PROP_SHAPE._serialized_start=886
+  _PROP_SHAPE._serialized_end=931
+  _WALKERPOSE._serialized_start=934
+  _WALKERPOSE._serialized_end=1230
+  _PROPPOSE._serialized_start=1232
+  _PROPPOSE._serialized_end=1340
+  _TIMESTEPDATA._serialized_start=1342
+  _TIMESTEPDATA._serialized_end=1468
+  _FITTEDTRAJECTORY._serialized_start=1471
+  _FITTEDTRAJECTORY._serialized_end=1729
+# @@protoc_insertion_point(module_scope)
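The repeated `_serialized_options = b'\020\001'` assignments in `mocap_pb2.py` are the wire encoding of the `[packed = true]` option on the corresponding fields in `mocap.proto`. In protobuf's `descriptor.proto`, `FieldOptions.packed` is field number 2 with a varint (bool) value, so the tag byte is `(2 << 3) | 0 = 0o20` and the value byte is `1`. A minimal sketch of decoding those two bytes by hand follows. `decode_varint_field` is a hypothetical helper that only handles a single-byte tag, not a general protobuf parser.

```python
def decode_varint_field(data):
  """Decodes one varint-typed field and returns (field_number, value).

  Hypothetical helper: assumes the tag fits in a single byte, which holds
  for field numbers below 16.
  """
  tag = data[0]
  field_number, wire_type = tag >> 3, tag & 0x07
  if wire_type != 0:
    raise ValueError('expected wire type 0 (varint)')
  value, shift = 0, 0
  for byte in data[1:]:
    value |= (byte & 0x7F) << shift  # Low 7 bits carry the payload.
    shift += 7
    if not byte & 0x80:  # A clear high bit marks the last varint byte.
      break
  return field_number, value


field_number, value = decode_varint_field(b'\020\001')
print(field_number, value)  # prints "2 1", i.e. packed = true
```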
diff --git a/dm_control/locomotion/mocap/props.py b/dm_control/locomotion/mocap/props.py
new file mode 100644
index 00000000..f25ff74f
--- /dev/null
+++ b/dm_control/locomotion/mocap/props.py
@@ -0,0 +1,110 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Props that are constructed from motion-capture data."""
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+from dm_control.locomotion.mocap import mocap_pb2
+import numpy as np
+
+_DEFAULT_LIGHT_PROP_RGBA = np.array([0.77, 0.64, 0.21, 1.])
+_DEFAULT_LIGHT_PROP_MASS = 3.
+
+_DEFAULT_HEAVY_PROP_RGBA = np.array([0.77, 0.34, 0.21, 1.])
+_DEFAULT_HEAVY_PROP_MASS = 10.
+
+_PROP_SHAPE = {
+    mocap_pb2.Prop.SPHERE: 'sphere',
+    mocap_pb2.Prop.BOX: 'box',
+}
+
+
+def _default_prop_rgba(prop_mass):
+  normalized_mass = np.clip(
+      (prop_mass - _DEFAULT_LIGHT_PROP_MASS) /
+      (_DEFAULT_HEAVY_PROP_MASS - _DEFAULT_LIGHT_PROP_MASS), 0., 1.)
+  return ((1 - normalized_mass) * _DEFAULT_LIGHT_PROP_RGBA +
+          normalized_mass * _DEFAULT_HEAVY_PROP_RGBA)
+
+
+class Prop(composer.Entity):
+  """A prop that is constructed from motion-capture data."""
+
+  def _build(self, prop_proto, rgba=None, priority_friction=False):
+    rgba = _default_prop_rgba(prop_proto.mass) if rgba is None else rgba
+    self._mjcf_root = mjcf.RootElement(model=str(prop_proto.name))
+    self._geom = self._mjcf_root.worldbody.add(
+        'geom', type=_PROP_SHAPE[prop_proto.shape],
+        size=prop_proto.size, mass=prop_proto.mass, rgba=rgba)
+    if priority_friction:
+      self._geom.priority = 1
+      self._geom.condim = 6
+      # Torsional and rolling friction have units of length which correspond
+      # to the scale of the surface contact "patch" that they approximate.
+      self._geom.friction = [.7, prop_proto.size[0]/4, prop_proto.size[0]/2]
+
+    self._body_geom_ids = ()
+    self._position = self._mjcf_root.sensor.add(
+        'framepos', name='position', objtype='geom', objname=self.geom)
+
+    self._orientation = self._mjcf_root.sensor.add(
+        'framequat', name='orientation', objtype='geom', objname=self.geom)
+
+  def _build_observables(self):
+    return Observables(self)
+
+  @property
+  def mjcf_model(self):
+    return self._mjcf_root
+
+  def update_with_new_prop(self, prop):
+    self._geom.size = prop.geom.size
+    self._geom.mass = prop.geom.mass
+    self._geom.rgba = prop.geom.rgba
+
+  @property
+  def geom(self):
+    return self._geom
+
+  def after_compile(self, physics, random_state):
+    del random_state  # unused
+    self._body_geom_ids = (physics.bind(self._geom).element_id,)
+
+  @property
+  def body_geom_ids(self):
+    return self._body_geom_ids
+
+  @property
+  def position(self):
+    """Ground-truth position sensor."""
+    return self._position
+
+  @property
+  def orientation(self):
+    """Ground-truth orientation sensor."""
+    return self._orientation
+
+
+class Observables(composer.Observables):
+
+  @define.observable
+  def position(self):
+    return observable.MJCFFeature('sensordata', self._entity.position)
+
+  @define.observable
+  def orientation(self):
+    return observable.MJCFFeature('sensordata', self._entity.orientation)
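In `props.py`, `_default_prop_rgba` picks a prop colour by linearly interpolating between the "light" and "heavy" colours according to the prop's mass, clamped to the 3 kg to 10 kg range. With `priority_friction` set, the torsional and rolling friction coefficients are derived from the first entry of `size`. The sketch below reproduces that arithmetic in plain Python, without NumPy or MuJoCo, so the numbers can be inspected directly. The function names here are illustrative and are not the module's API.

```python
LIGHT_RGBA = (0.77, 0.64, 0.21, 1.0)
HEAVY_RGBA = (0.77, 0.34, 0.21, 1.0)
LIGHT_MASS = 3.0
HEAVY_MASS = 10.0


def default_prop_rgba(mass):
  """Interpolates between the light and heavy colours, clamped to [0, 1]."""
  t = (mass - LIGHT_MASS) / (HEAVY_MASS - LIGHT_MASS)
  t = min(max(t, 0.0), 1.0)
  return tuple((1 - t) * light + t * heavy
               for light, heavy in zip(LIGHT_RGBA, HEAVY_RGBA))


def priority_friction(size):
  """Returns [sliding, torsional, rolling] friction as set in Prop._build."""
  return [0.7, size[0] / 4, size[0] / 2]


# The 3 kg box from test_001.textproto sits at the light end of the range.
box_rgba = default_prop_rgba(3.0)
box_friction = priority_friction([0.1775, 0.1275, 0.1775])

# A prop halfway between the two reference masses gets the midpoint colour.
mid_rgba = default_prop_rgba(6.5)
```

Only the green channel differs between the two reference colours, so heavier props shift from yellow towards red while the other channels stay fixed.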
diff --git a/dm_control/locomotion/mocap/test_001.textproto b/dm_control/locomotion/mocap/test_001.textproto
new file mode 100644
index 00000000..a1323eb5
--- /dev/null
+++ b/dm_control/locomotion/mocap/test_001.textproto
@@ -0,0 +1,247 @@
+identifier: "cmuv2019_001"
+year: 2020
+month: 7
+day: 7
+dt: 0.05
+walkers {
+ name: "cmuv2019_CMU"
+ model: CMU_2019
+ markers {
+ marker {
+ name: "_left_shoulder"
+ parent: "lhumerus"
+ }
+ marker {
+ name: "_left_elbow"
+ parent: "lradius"
+ }
+ marker {
+ name: "_left_wrist"
+ parent: "lhand"
+ }
+ marker {
+ name: "_left_hip"
+ parent: "lfemur"
+ }
+ marker {
+ name: "_left_knee"
+ parent: "ltibia"
+ }
+ }
+}
+props {
+ name: "cmuv2019_box"
+ shape: BOX
+ size: [0.1775, 0.1275, 0.1775]
+ mass: 3.0
+}
+timesteps {
+ walkers {
+ position: [0.0, 0.0, 0.94]
+ quaternion: [0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259]
+ joints: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4595694075999999, 0.4172453, 0.00827831100000001, 0.46164259239999994, 0.4172453, 0.00827831100000001, -0.3791909916481448, -0.8535581408327323, 0.049446100000000104, 0.3791909973295273, -0.8535581387648811, 0.04944610000000012]
+ velocity: [0.0, 0.0, 0.0]
+ angular_velocity: [0.0, 0.0, 0.0]
+ joints_velocity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ appendages: [-0.4595694075999999, 0.4172453, 0.00827831100000001, 0.46164259239999994, 0.4172453, 0.00827831100000001, -0.3791909916481448, -0.8535581408327323, 0.049446100000000104, 0.3791909973295273, -0.8535581387648811, 0.04944610000000012, -0.0012630876000000006, 0.5211893, 0.0037585310000000233]
+ body_positions: [0.0, 0.0, 0.94, 0.03503343204707343, 0.10193700000000004, 0.841785061528436, -0.022359767062621314, 0.24043608588488902, 0.4656142533441546, -0.07985898112408829, 0.37919099732952743, 0.08874859401537516, 0.02206814054904281, 0.3869807066780439, 0.05154761775054591, 0.0, 0.0, 0.94, 0.035033432047073425, -0.10193700000000004, 0.841785061528436, -0.022359767356247333, -0.24043608053615773, 0.465614251419649, -0.07985898143597565, -0.3791909916481449, 0.08874859197117962, 0.022068139621451835, -0.3869806897809207, 0.05154761167085957, 0.0, 0.0, 0.94, 0.009114696982704808, 0.0005658620000000002, 1.0531324717081434, 0.027168690390768996, 0.0010783900000000003, 1.1652520710456704, 0.048981475682166986, 0.0010365924000000002, 1.2771876386109813, 0.07319323870385834, -0.002264837599999999, 1.3650207489466142, 0.08232483949401602, -0.0012630875999999984, 1.4546601148253597, 0.048981475682166986, 0.0010365924000000002, 1.2771876386109813, 0.07111537503931514, 0.1847995924, 1.3512235088587716, 0.07111537503931514, 0.46164259240000005, 1.3512235088587716, 0.07111537503931514, 0.6431785924000001, 1.3512235088587716, 0.07111537503931514, 0.7339461924000001, 1.3512235088587716, 0.07111537503931514, 0.7691405924000001, 1.3512235088587716, 0.08100097861124037, 0.7339461924000001, 1.349715250917059, 0.048981475682166986, 0.0010365924000000002, 1.2771876386109813, 0.07111537503931514, -0.18272640759999997, 1.3512235088587716, 0.07111537503931514, -0.45956940760000003, 1.3512235088587716, 0.07111537503931514, -0.6411054076, 1.3512235088587716, 0.07111537503931514, -0.7318730076, 1.3512235088587716, 0.07111537503931514, -0.7670674075999999, 1.3512235088587716, 0.08100097861124037, -0.7318730076, 1.349715250917059]
+ body_quaternions: [0.4607529753752591, 0.53638297482568, 0.53638297482568, 0.4607529753752591, 0.37374430412809334, 0.6213759435143635, 0.43509232145296073, 0.5337619354788383, 0.37374430412809334, 0.6213759435143635, 0.43509232145296073, 0.5337619354788383, 0.7036562752040176, 0.17510201144637338, -0.06976995317475732, 0.6850834150579287, 0.7036562752040176, 0.17510201144637338, -0.06976995317475732, 0.6850834150579287, 0.4607529753752591, 0.53638297482568, 0.53638297482568, 0.4607529753752591, 0.5337619354788383, 0.43509232145296073, 0.6213759435143635, 0.37374430412809334, 0.5337619354788383, 0.43509232145296073, 0.6213759435143635, 0.37374430412809334, 0.6850834150579287, -0.06976995317475732, 0.17510201144637338, 0.7036562752040176, 0.6850834150579287, -0.06976995317475732, 0.17510201144637338, 0.7036562752040176, 0.4607529753752591, 0.53638297482568, 0.53638297482568, 0.4607529753752591, 0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259, 0.4607529753752591, 0.53638297482568, 0.53638297482568, 0.4607529753752591, 0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259, 0.4607529753752591, 0.53638297482568, 0.53638297482568, 0.4607529753752591, 0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259, 0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259, -1.387778780781446e-17, 0.8257302323277214, 0.0, -0.5640652297562824, -1.387778780781446e-17, 0.8257302323277214, 0.0, -0.5640652297562824, 6.938896330352964e-18, 0.07562950088251945, -1.201851538898785e-17, 0.9971359880158077, 6.938896330352964e-18, 0.07562950088251945, -1.201851538898785e-17, 0.9971359880158077, 6.938896330352964e-18, 0.07562950088251945, -1.201851538898785e-17, 0.9971359880158077, 0.38158688963013887, 0.06987256465818399, 0.02894211657475843, 0.921233751150411, 0.460752975375259, 0.5363829748256799, 0.5363829748256799, 0.460752975375259, -0.5640652297562824, 0.0, 0.8257302323277214, 
1.387778780781446e-17, -0.5640652297562824, 0.0, 0.8257302323277214, 1.387778780781446e-17, 0.9971359880158077, 1.201851538898785e-17, 0.07562950088251945, -6.938896330352964e-18, 0.9971359880158077, 1.201851538898785e-17, 0.07562950088251945, -6.938896330352964e-18, 0.9971359880158077, 1.201851538898785e-17, 0.07562950088251945, -6.938896330352964e-18, 0.921233751150411, 0.02894211657475845, 0.06987256465818399, 0.38158688963013887]
+ }
+ props {
+ position: [1.0, 0.0, 3.0]
+ quaternion: [0.5, 0.5, 0.5, 0.5]
+ velocity: [0.0, 0.0, 0.0]
+ angular_velocity: [0.01, 0.01, 0.01]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.06774922904985167, 8.62317917932669e-06, 0.9221663814038061]
+ quaternion: [0.34965798554603905, 0.6145544383194512, 0.6145408899535888, 0.34985944329230406]
+ joints: [0.07307174594306007, 0.09137216490576963, -1.0635390156428057, 1.1439872254968426, -0.2427867720496015, 0.14608977119710087, -0.48311190169124024, -0.07377312166316363, -0.09156196080113051, -1.0633751318220606, 1.144180726408987, 0.2428658967623869, 0.1461727560244489, -0.4830965949955947, 0.00044566584528417673, -7.502971352540864e-05, -0.03759325808518402, -0.00018197325930370604, -9.454844871630644e-05, -0.11292380266293534, -0.0006363603995869877, -0.00022046153178158413, -0.06763273331212036, -0.0020438055516730953, -0.001788650676703215, -0.11334419288358563, 0.0022718115660939607, 0.0009697136799126119, -0.024448603402237083, 0.0006321955823730347, 0.00022336780489048706, 0.03510977034659875, 0.004231622042487199, 0.10610710564621416, -0.05234815425996891, 0.0288465580072744, -0.009510687604812707, 1.432868496510371, 0.795210756377402, 0.007592523562062222, 0.0013253161842978276, 0.48787255070911767, 0.0010474850518309058, 0.4876968995369291, -0.004276841719898402, -0.10592101794082341, 0.051633563032365506, -0.028719259598600666, -0.009542955319520506, 1.4329433137289531, -0.795210490041185, -0.007590405731788603, 0.0013297958018352837, 0.4878737645870649, -0.001046968878140477, 0.48769739258747735]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4563459945600351, 0.4147600038720176, -0.11430374863996037, 0.45844934217199473, 0.41457251379982074, -0.1140596780134295, -0.37539588818761316, -0.6279551938590521, 0.36927629711185916, 0.37496014304294756, -0.6281228589604666, 0.36945788763315346]
+ velocity: [-2.0741340390160774, -0.0007726576435444993, -1.1578987712566824]
+ angular_velocity: [-4.601007706465458, -0.00990462010442595, 0.023388009388295074]
+ joints_velocity: [3.085238248200053, -1.750134235940812, -10.352677620382744, 23.792490352058444, -5.760812901871518, 9.251341461081932, -5.289551464826017, -3.1321143752914598, 1.7561072017545585, -10.374327416018621, 23.80758181245679, 5.758817383304328, 9.253432744861898, -5.289086315098841, -0.008754897878163529, -0.018191272988579994, 6.773988366317279, -0.0034661163723865114, 0.0021132577879451396, 4.145828070365569, -0.002615574686705596, 0.012506229183704809, 1.477732745026705, -0.059994860047335494, -0.011752106947118372, -1.8004638072693577, 0.032373130186844054, 0.014677679719965545, -0.925003442302155, 0.012255704707537072, 0.0023617398542758452, 0.12787365073433424, 1.2598731052356564, 0.017062782666644516, -4.347934598698709, 2.2788708363209147, 5.102484729821626, 18.11281490461815, 18.593778018364073, -0.20986706994619514, -0.1556438615844056, 8.425210880841203, -0.010592030415427498, 8.448212935616336, -1.2643372991069797, -0.011965485316861924, 4.348337876956638, -2.2731206589024127, 5.125385824380853, 18.111330703658762, -18.593775367435974, 0.20984853793410896, -0.15542666149662246, 8.425256345717887, 0.010586978140537169, 8.448230644612204]
+ appendages: [-0.4563459945600351, 0.4147600038720176, -0.11430374863996037, 0.45844934217199473, 0.41457251379982074, -0.1140596780134295, -0.37539588818761316, -0.6279551938590521, 0.36927629711185916, 0.37496014304294756, -0.6281228589604666, 0.36945788763315346, -0.001023645686708486, 0.5093276371211684, -0.10341591635224448]
+ body_positions: [-0.06774922904985167, 8.62317917932669e-06, 0.9221663814038061, -0.07213322278117704, 0.1019718603853017, 0.8180081060418277, 0.14602280126335385, 0.21213958479390865, 0.4951281365535957, -0.07091255643790509, 0.3751557079538947, 0.1935462995376266, -0.011812245850253991, 0.3796646341008824, 0.10232923308181356, -0.06774922904985167, 8.62317917932669e-06, 0.9221663814038061, -0.07210788716113353, -0.10190213130300368, 0.8179556924639897, 0.14606846259774695, -0.21204828322071526, 0.49508209699879446, -0.07088981326521829, -0.375200362348245, 0.19359027976332566, -0.0118202138790027, -0.3797074282484065, 0.1023532306841076, -0.06774922904985167, 8.62317917932669e-06, 0.9221663814038061, -0.020706677827992598, 0.0005044145960813131, 1.025457749005814, 0.02285766927629017, 0.000968176303590253, 1.1303336113621878, 0.06291108958126879, 0.0009506474214829288, 1.2371095161815209, 0.09174405964962656, -0.002145143944095809, 1.3235435377977391, 0.10343770057167725, -0.001119906780878532, 1.4128845317587546, 0.06291108958126879, 0.0009506474214829288, 1.2371095161815209, 0.07872518687546184, 0.18450833612564424, 1.3132445869272424, 0.04584026872998435, 0.45836545878996193, 1.3369702096500737, 0.2167993613255886, 0.5002792357599278, 1.3813722732845441, 0.30227853092875695, 0.5212360318912685, 1.4035732072653972, 0.33538617493324063, 0.5296112757065105, 1.4120811441990861, 0.3041627990987092, 0.5116444169756863, 1.4056828178396228, 0.06291108958126879, 0.0009506474214829288, 1.2371095161815209, 0.078712960700794, -0.18256725736831744, 1.3133429684435278, 0.045839852313722486, -0.45642992831218865, 1.3370208665229766, 0.2168065727526068, -0.49833213298775986, 1.3814044788331403, 0.3022895562606079, -0.5192831429974019, 1.403596187192496, 0.33539875831227467, -0.5276560351725276, 1.4121003748365701, 0.30417288867312703, -0.5096916619748965, 1.405707241625445]
+ body_quaternions: [0.34965798554603905, 0.6145544383194512, 0.6145408899535888, 0.34985944329230406, 0.5726621040644053, 0.48678008851225224, 0.18981567805061383, 0.6317224614992868, 0.2180091088147162, 0.7192835105000539, 0.501558786464931, 0.4284180711832285, 0.7103327928477221, 0.39848584864718817, 0.13578987699787573, 0.5640899406459817, 0.7850329333864567, 0.2169952572421104, -0.0030905553358065163, 0.5801954845574048, 0.34965798554603905, 0.6145544383194512, 0.6145408899535888, 0.34985944329230406, 0.6316115586867354, 0.1896777146106577, 0.4868966337042786, 0.5727310639391163, 0.4283509913859179, 0.501424196330124, 0.719439921682768, 0.21793439976856172, 0.5639080678464573, 0.13578642466951799, 0.398683739556389, 0.7103668163044202, 0.580018089565342, -0.0030459618885442805, 0.21718527263594087, 0.7851116450818934, 0.36109499971498016, 0.6080249298058993, 0.6077054620396355, 0.3614002729973552, 0.39489094038357336, 0.5866371036684519, 0.5863842087838459, 0.3950590000131756, 0.4146847179960228, 0.5728010515869348, 0.5728414807057759, 0.41447337422839964, 0.44738571489910156, 0.5481246913056548, 0.5487115629368603, 0.4453324214389611, 0.4532856507531888, 0.5430309994409185, 0.5428110295692035, 0.45277548370630316, 0.44347760696314553, 0.551021668256789, 0.5505579306387048, 0.4433832408251061, 0.3829119275227035, 0.5511787038740584, 0.5927675999216485, 0.4452045200126448, 0.0032225265534728802, 0.7823922864763928, 0.07336444732359353, -0.6184412528718942, -0.5113665304409908, 0.5921590547816429, -0.3508022502942293, -0.5145772114788599, 0.6194036355621222, -0.16030645482568087, -0.029951352543374832, 0.7679478454130008, 0.6165900158032838, -0.1600104493903233, -0.02883212513928294, 0.7703129994084, 0.6369821387957654, -0.006351878270195012, 0.15807070909961143, 0.7544713774705832, 0.8716643336547161, 0.07585500554215457, 0.029764630783863683, 0.48328187874446493, 0.44537599255454896, 0.5926668421515962, 0.5512616582362491, 0.3827490347848397, 
-0.6184711596904625, 0.07329359224559324, 0.7823750497346222, 0.003279560888063653, -0.5145401091414736, -0.35089457407861024, 0.5921643821955331, -0.5113343509151851, 0.7679215433829385, -0.02985845061430381, -0.16033303672263793, 0.6194338491482172, 0.7702861622881302, -0.028737585378396106, -0.16003524569620597, 0.616621520471836, 0.7544224052349193, 0.15815642606456237, -0.006367944703963319, 0.6370187035356899, 0.48322710815269176, 0.02983134802097269, 0.07580252465049683, 0.8716969828319066]
+ }
+ props {
+ position: [1.0, -8.131516293641283e-23, 2.98651125]
+ quaternion: [0.4996249531377709, 0.5001249531167505, 0.5001249750738, 0.5001249311716818]
+ velocity: [2.168404344971009e-21, 1.8070036208091737e-21, -0.4904999999999999]
+ angular_velocity: [0.010001596326029406, 0.01, 0.009998403444594522]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.13723881550104733, -1.3222658979455727e-05, 0.8452208830091771]
+ quaternion: [0.4643381834881084, 0.5334183548179003, 0.5330957734461546, 0.4643961740068223]
+ joints: [0.19995585088543028, -0.030416185500267402, -1.2426479687348988, 1.9161311455440067, -0.357363355806902, 0.34409372635887747, -0.5820304109484147, -0.20061012540521198, 0.030619555328625384, -1.243222512648805, 1.916312452535122, 0.3572272612608726, 0.3441032338953631, -0.58201897491688, -3.7725485066861605e-06, -0.00011415392212689341, 0.18722474430442732, -0.00012235855338922765, -1.4606180524132012e-05, 0.14135891298901201, -0.0002471226156018929, 0.0001497046838774384, 0.11680590128227201, -0.002639320432075611, -0.00010786383448640731, 0.04267259092130023, 0.0010484209326550688, 0.00035878845820517646, 0.061940760302219885, 0.00024185408196590418, 7.388699685979198e-05, 0.08437977123174181, 0.10578800611731765, -0.008565500568131833, -0.20172542299208276, 0.09150190199681921, 0.388080227104828, 1.4145503104261044, 1.3757007772133516, -0.0020542074151783, -0.0009275854039035471, 0.7162921346933948, 0.0001658969516671765, 0.7161788464744924, -0.10590282128007186, 0.008436801901822777, 0.2022000818458312, -0.09174160225975268, 0.388058086404679, 1.4145107780083068, -1.3757009494694494, 0.0020568320646572624, -0.0009245077817185168, 0.7162929545939353, -0.00016566663448737475, 0.7161790600796231]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.40951064458681913, 0.49287610860153597, 0.26847142961705583, 0.4118813705505828, 0.49241393489522933, 0.2684268428516736, -0.3323092115698259, -0.47547592374141967, 0.1798862373521949, 0.3321489841549144, -0.4757125890644812, 0.1796874997543243]
+ velocity: [-0.7478660387446584, -0.0002181121913792564, -1.5990815583355693]
+ angular_velocity: [-6.75394501200118, 0.0019498095369661495, -0.005825779045285072]
+ joints_velocity: [1.7290850565887572, -1.5213262605745332, -2.152808697271129, 9.012462890306992, -0.514312712638993, 1.0214174148241193, -0.696276437961935, -1.698073810789286, 1.515588336235765, -2.152744278061892, 9.007272797591204, 0.5147268400586045, 1.0188818999134257, -0.6967782663327147, -0.003853333031339787, 0.003879293252467237, 2.2896560694152597, 0.004838220343049543, 0.0010300698738088677, 3.489220539339997, 0.011163762577984343, 0.0022162586365436655, 3.0029384929250855, 0.022863134310065956, 0.050111245476063576, 5.711310674558422, -0.03652331520156412, -0.025619078232111646, 3.511027094308463, -0.011822749342340252, -0.005523948306568948, 1.7320953650872684, 1.6489291776085762, -3.1237138543181033, 0.45901917473220377, 0.05983975006207459, 7.7407985019698815, -5.53486280768769, 6.282104697263797, -0.02025546734585853, -0.010762388783781841, 2.2544571834567217, -0.004999645348959707, 2.254427010826523, -1.648743734107023, 3.1193365446491645, -0.44632503218602826, -0.06944073580478716, 7.72374276143117, -5.534752011714334, -6.282099596234206, 0.020274254610376116, -0.010880611065493994, 2.254441790924957, 0.005001927218850058, 2.254418008286578]
+ appendages: [-0.40951064458681913, 0.49287610860153597, 0.26847142961705583, 0.4118813705505828, 0.49241393489522933, 0.2684268428516736, -0.3323092115698259, -0.47547592374141967, 0.1798862373521949, 0.3321489841549144, -0.4757125890644812, 0.1796874997543243, -0.0007791277140850827, 0.47470757563115634, 0.19908239564151226]
+ body_positions: [-0.13723881550104733, -1.3222658979455727e-05, 0.8452208830091771, -0.10085160786040216, 0.10194854058288313, 0.7475252463323175, 0.26254980322020666, 0.16006649080558968, 0.5785809616952109, -0.024548298758613085, 0.3322822899062934, 0.3494451752512844, -0.007862951515434684, 0.35548584352791746, 0.24448363645700508, -0.13723881550104733, -1.3222658979455727e-05, 0.8452208830091771, -0.10091076797762893, -0.10192543751663716, 0.7474515580912209, 0.26255402695336155, -0.16002556415919525, 0.5786375445884808, -0.024511729646446656, -0.3321759758581043, 0.3494121153955412, -0.007801742554840171, -0.3553786806746207, 0.24445430882568187, -0.13723881550104733, -1.3222658979455727e-05, 0.8452208830091771, -0.1087076076099176, 0.0005035772945685112, 0.9550755922835845, -0.056798371548950136, 0.0009741475712263213, 1.0560816499916166, 0.009864883970342811, 0.0009172153836991861, 1.1486093634491086, 0.0715392123049652, -0.0021611278372367284, 1.2156807076087586, 0.12520342054274686, -0.0010285650229683693, 1.2880580217816446, 0.009864883970342811, 0.0009172153836991861, 1.1486093634491086, 0.07273803637338581, 0.1754610803814716, 1.2215577400701352, 0.19644320043486863, 0.411608357360038, 1.2962139026364627, 0.36925863776574075, 0.35649798351075257, 1.2889344376608316, 0.4556659756462364, 0.3289429180173937, 1.2852947212127326, 0.4891590080971284, 0.31821660081360115, 1.2839498465542176, 0.45329232984217455, 0.3208563084490817, 1.2906773368674624, 0.009864883970342811, 0.0009172153836991861, 1.1486093634491086, 0.07268313225478026, -0.17364384932908078, 1.2215638843639596, 0.19631253950420743, -0.4097837640391333, 1.2963686769436091, 0.36914499572605597, -0.3547237079942563, 1.2891125346725532, 0.45556084301454014, -0.32719380129223036, 1.2854844795253522, 0.48905712416721475, -0.31647707511862644, 1.2841440615230229, 0.4531878106359998, -0.3191039638191227, 1.2908625132728386]
+ body_quaternions: [0.4643381834881084, 0.5334183548179003, 0.5330957734461546, 0.4643961740068223, 0.6536192877138164, 0.35187399691940674, -0.0319596461175699, 0.6692870072252146, 0.08803749954841115, 0.7370772562308641, 0.5291506120063212, 0.4110549194861684, 0.6073166614922566, 0.5488672579222622, 0.2555806313792272, 0.514382879495996, 0.7392645460261854, 0.35153481354469185, 0.09724500124972948, 0.5660871095394567, 0.4643381834881084, 0.5334183548179003, 0.5330957734461546, 0.4643961740068223, 0.6693648487724843, -0.032028161037359996, 0.351694619762155, 0.6536327642940704, 0.4111077639226187, 0.5292121600210923, 0.736993112599055, 0.08812518431136013, 0.5145138127007334, 0.2555601088783138, 0.5487587100323837, 0.6073124776042835, 0.5662060932595682, 0.09719101123806753, 0.3514362571402774, 0.739227383460253, 0.41247223127869465, 0.5745154691191495, 0.5741430926109358, 0.41250253741514, 0.3709013727811242, 0.6021801361253388, 0.6018707616703803, 0.37089465023300383, 0.33512661837726576, 0.622700946609304, 0.6225955688067318, 0.3351245118595735, 0.3222578440675861, 0.6289152535386919, 0.6303964054860298, 0.3212722486514631, 0.3023388730759857, 0.6388563654045908, 0.6397792915661179, 0.30188774120910494, 0.2750616712491991, 0.6511033424033731, 0.6518793122067892, 0.2746977918323278, 0.31945889265984345, 0.6562493986330435, 0.5874402699580009, 0.34956640560269286, -0.26378163614180183, 0.6741985459175073, -0.06260104527608189, -0.6869910321448273, -0.6385802824600048, 0.3410859917033059, -0.49396794190458154, -0.48153010444584365, 0.7715022741422121, -0.14996338275526347, 0.2378515204272243, 0.5707204911832507, 0.7720183080001113, -0.15056567442726754, 0.23743301314897758, 0.5700379584412821, 0.7758115113897834, 0.1296101423430767, 0.42218775987860235, 0.45063866368255073, 0.9528928356857052, 0.11106761956637048, 0.23256124391236693, 0.15992027830343505, 0.3496290979591263, 0.5875341387674384, 0.6561611365191383, 0.31939895517206196, -0.6869432912166747, 
-0.062319640596469744, 0.6742373286812904, -0.26387345766682796, -0.48168642071728784, -0.4937134949018898, 0.34106843356246375, -0.6386685374085503, 0.5708606355563255, 0.237582356250253, -0.1498920246324431, 0.7714953918996491, 0.5701766067278953, 0.23716453827088957, -0.15049316071360327, 0.772012582478711, 0.45086245229943955, 0.42198510658495614, 0.1296763659544909, 0.7757806771503502, 0.16011940011423434, 0.2324005345574235, 0.11124340267787497, 0.9528781006010841]
+ }
+ props {
+ position: [1.0, -3.3790967709131553e-22, 2.9484974999999998]
+ quaternion: [0.49924981260223683, 0.5002498124337476, 0.5002498962951576, 0.5002497286689082]
+ velocity: [-1.2287624621502383e-20, 1.807003620809175e-21, -0.9810000000000002]
+ angular_velocity: [0.010003192397159901, 0.01, 0.00999680663436337]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.1531883905119414, -2.9259929453824055e-05, 0.7741990896473868]
+ quaternion: [0.520213356161675, 0.47901428056125067, 0.4789269647794053, 0.5201464654273641]
+ joints: [0.2028640569326736, -0.02375885559296947, -1.3126325595047423, 2.0279587985216043, -0.3629317891104318, 0.3578291249238457, -0.5973530383586724, -0.2019976250312827, 0.023731115272901226, -1.3129035970497456, 2.0279013330471876, 0.3629766572693475, 0.3577834663700596, -0.5973593354000556, -0.00026599560136988796, -2.9942141080754005e-05, 0.2554498913250937, 4.865584724259407e-05, -3.3365322074479035e-05, 0.24333629146306704, 0.0002764623660661184, 0.0001328975082606084, 0.20461696746623595, -0.00037261100285889694, 0.0015299382326677405, 0.28938092829580114, -0.0003679242842965982, -0.000727046160015426, 0.21969719355152634, -0.00015054247992474712, -0.00018879686699335832, 0.162574180196977, 0.10063584774829659, -0.11585180857321656, -0.0275627489961284, 0.015488567087991578, 0.5870021618934625, 1.4100803836510483, 1.5266061189154547, -0.001118033665621531, -0.0008498591000338199, 0.7686980093551765, 0.0001815445822710972, 0.7686918149528452, -0.10060573135788205, 0.11576531120972167, 0.027703238187131187, -0.015660212920236588, 0.5867121379905871, 1.410021733279601, -1.526607542516006, 0.0011235890071457722, -0.0008541837115066165, 0.7686976316676686, -0.00018101350784199504, 0.768691342103616]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.3761327036485027, 0.32668301477673284, 0.42414600315720347, 0.3781475149572606, 0.32665760030549895, 0.4242944219105505, -0.3138699283220329, -0.44234693697861543, 0.1749853286591507, 0.3141830300819743, -0.4421793498963694, 0.17485665747029264]
+ velocity: [-0.11525057355727727, -0.00022636255271542508, -1.2835805117943375]
+ angular_velocity: [-2.711493505617659, 8.999511411179363e-05, -0.006760593389952753]
+ joints_velocity: [-1.364906400878748, 1.238532828838383, -0.5738299987003761, -2.7979199130134345, -0.07922066622647758, 0.010372528337368923, -0.13481744071994714, 1.3878043105352547, -1.2350715690834952, -0.5688163316964466, -2.8030180596135743, 0.083261766794857, 0.00985587111696367, -0.13503047966280657, -0.0040398977685366716, 0.0016401170131495003, 0.8131959923361487, 0.0016500397291427985, -0.0001724494984062483, 1.2523957335446931, 0.006956133348950269, -0.0009335193650836773, 1.107657472789275, 0.049098012947404734, 0.01738223711129573, 3.849090852669013, -0.016614868674841212, -0.016986365418334853, 2.4740024613385674, -0.0036761608867629915, -0.004584250583777999, 1.2051621538184683, -1.305257179822639, -1.0487657551810372, 4.216789892384123, -3.006882126916064, 0.5242326834745742, 2.452140800616469, 1.344989097953823, 0.002706287568225472, 0.050477927389480086, 0.46243371905241226, -0.000491602402823284, 0.45739525087352856, 1.3088959522690737, 1.0506099482327005, -4.229870946542185, 3.017416196426645, 0.5247953657856869, 2.450601133117795, -1.3449998086299033, -0.0027343118135425984, 0.05039558162682622, 0.4624193732763327, 0.0004912673803315129, 0.45739011174728506]
+ appendages: [-0.3761327036485027, 0.32668301477673284, 0.42414600315720347, 0.3781475149572606, 0.32665760030549895, 0.4242944219105505, -0.3138699283220329, -0.44234693697861543, 0.1749853286591507, 0.3141830300819743, -0.4421793498963694, 0.17485665747029264, -0.0010508296567865096, 0.37097045589838534, 0.3175552065324581]
+ body_positions: [-0.1531883905119414, -2.9259929453824055e-05, 0.7741990896473868, -0.09633440243634744, 0.10190136895008117, 0.6867780486727294, 0.3014091219793111, 0.14812958042075886, 0.6264166845840752, 0.05753535867516596, 0.3141328919704526, 0.34792916533874046, 0.09140253680064381, 0.33855557858643875, 0.24747879468917788, -0.1531883905119414, -2.9259929453824055e-05, 0.7741990896473868, -0.09636564282841331, -0.10197262858336636, 0.686772592507466, 0.30138729123367475, -0.1481921045144736, 0.6264665686403136, 0.0575811544434347, -0.31392007641339814, 0.34775593511030345, 0.09147922019161225, -0.3382485593871733, 0.2472931241099785, -0.1531883905119414, -2.9259929453824055e-05, 0.7741990896473868, -0.14163184410615864, 0.0005608644373826922, 0.8871081328816915, -0.09484902080092238, 0.0010819153404738424, 0.9905881303115536, -0.024747238924720943, 0.0010133831722549554, 1.0805387062769118, 0.05326647979890475, -0.0022254575768118874, 1.127603914036021, 0.13273921819963636, -0.0011344975764013443, 1.1700597641460269, -0.024747238924720943, 0.0010133831722549554, 1.0805387062769118, 0.05612215415510642, 0.17367474723744183, 1.138747420187712, 0.242823108203629, 0.37804792943357635, 1.134697382484546, 0.3939241182407955, 0.2790119810818799, 1.1169298560436416, 0.4694742903204354, 0.22949422512379203, 1.1080461319725072, 0.49877445529741476, 0.21029484269773752, 1.104650570437223, 0.46720187873220126, 0.22454318655972805, 1.116432034590044, -0.024747238924720943, 0.0010133831722549554, 1.0805387062769118, 0.056024858557679205, -0.1717049325537082, 1.138713541336241, 0.24255751945768783, -0.37623225741111815, 1.1346903021892445, 0.39375006483780695, -0.27732911786000874, 1.1169617234415161, 0.4693460043872324, -0.22787776600958087, 1.1080974731311517, 0.4986639364338293, -0.20870414590762165, 1.1047097092584957, 0.4670761104716261, -0.22292526192835527, 1.1164831922379033]
+ body_quaternions: [0.520213356161675, 0.47901428056125067, 0.4789269647794053, 0.5201464654273641, 0.6528429305955671, 0.25321222158664486, -0.12929185534613227, 0.7021134487756964, 0.13005637130023787, 0.688044810658192, 0.5277243108978555, 0.4808188125457771, 0.6185543863540375, 0.5001557934654894, 0.2084214151761894, 0.5690300230049741, 0.7383426167304535, 0.29599920090186005, 0.03175366672291738, 0.6051663887186207, 0.520213356161675, 0.47901428056125067, 0.4789269647794053, 0.5201464654273641, 0.7022415695487784, -0.12903639759792462, 0.25283536555325214, 0.6529017261565266, 0.48068482371743637, 0.527954272788012, 0.6878918140344938, 0.13042713780863258, 0.5690230404837747, 0.20860431175060157, 0.49976672335833344, 0.6188135767267339, 0.6052136342044542, 0.0319286149931195, 0.2955487628649023, 0.7384767764030796, 0.4550428163706264, 0.541330328522093, 0.5413266699365336, 0.45482188503531223, 0.3859741177962138, 0.5925766543425764, 0.5925054923188379, 0.385764346060655, 0.32333800989902484, 0.6289441735003107, 0.6287589512925286, 0.3233325519400559, 0.2289103275332529, 0.6685714203427953, 0.6692300797222736, 0.22965935261677417, 0.15452478212053175, 0.6896285296680238, 0.6903816960268048, 0.154621138818968, 0.09809795713246336, 0.6998674183664378, 0.7006879001439398, 0.09797374192818922, 0.34067545913896413, 0.6783027118924935, 0.5775751616374788, 0.30042069708857494, -0.2632823896294485, 0.6245332427845731, -0.2481300331737685, -0.6921503439319938, -0.6052498377695754, 0.30501151111377733, -0.6375302614795028, -0.36632741866514673, 0.7595926495142511, -0.19019558578099496, 0.4424830021889437, 0.43709660113892873, 0.7597558817545771, -0.190765771434589, 0.4421910157919863, 0.43685984725211674, 0.7758548781971027, 0.10802679014701683, 0.5737330603075714, 0.2391856937577998, 0.9352072019530152, 0.005636718426254904, 0.35342564260219944, -0.02111946902310063, 0.3004612726288626, 0.5777938246844372, 0.6781054056738444, 0.3406616776255955, -0.6921839948032514, 
-0.24786588607251894, 0.6247188014834989, -0.263002355374849, -0.3665429212999558, -0.6373401501437023, 0.30535205637646157, -0.6051478674889006, 0.4373987274756935, 0.4423296026960646, -0.190462449556555, 0.759441196598066, 0.43716083836641767, 0.44203586462916944, -0.19103535647073147, 0.7596052845777299, 0.23952300853551123, 0.5737020416573257, 0.10772025952058253, 0.7758163709726238, -0.02071592671066637, 0.3533228442323849, 0.005451562549877831, 0.9352561673626221]
+ }
+ props {
+ position: [1.0, -8.836247705756862e-22, 2.88595875]
+ quaternion: [0.4988745784701102, 0.5003745779012548, 0.5003747636506126, 0.500374392478277]
+ velocity: [1.0842021724855045e-20, -9.75781955236954e-21, -1.4715000000000005]
+ angular_velocity: [0.010004788213350803, 0.01, 0.00999520956934724]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.15569106800425475, -2.5349889254349086e-05, 0.7110340629819992]
+ quaternion: [0.5307017144256975, 0.46741438714490635, 0.4674303991663966, 0.5304604631173941]
+ joints: [0.10030165095433968, 0.06584088709283259, -1.2976500407816167, 1.6622473287544532, -0.37098781044995754, 0.34781355518960805, -0.6018018400420576, -0.09904748377982578, -0.06544481296988632, -1.2976852228876372, 1.6619263683150796, 0.37110146289468804, 0.3477480761574192, -0.6018152550805871, -0.0003019407105120084, 4.555739932554888e-05, 0.2646937752102002, 9.161865235372928e-05, -8.559573783704712e-06, 0.26947991475390903, 0.0004270348566789865, 9.887970315557065e-05, 0.23326012321905, 0.0014502125183565567, 0.00172575937402643, 0.4104257494709456, -0.0005696340802183806, -0.001299799320692417, 0.298093013704532, -0.0001401243335734644, -0.00033738658672034767, 0.20046580749002343, 0.008646666866787271, -0.11135779996888027, 0.1342724977163687, -0.11683995277459243, 0.44052381767121007, 1.4293063901570344, 1.5564695586428159, -0.0005791684474391213, 0.0021023827894321397, 0.779401362253349, 0.00013872451949360865, 0.7790019249523398, -0.008477313389176086, 0.1113286820514927, -0.13460277297381873, 0.11713026503172874, 0.44026919787797547, 1.4292229083618941, -1.5564691839401825, 0.0005762029409875078, 0.0020985665066178593, 0.7794009742091197, -0.00013874366292699342, 0.7790018890798096]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4104686716817022, 0.2307089061442757, 0.36224880402168297, 0.4125447479411606, 0.23100998026860592, 0.3624951052340618, -0.32796629532752286, -0.5249443788731726, 0.2943328237185734, 0.32857207412066847, -0.5245883661509362, 0.2941879088549444]
+ velocity: [-0.002936663285674171, 0.0003180073375982305, -1.3438157923315677]
+ angular_velocity: [0.4937674754643902, -2.2400397947619738e-05, -0.00392454918520935]
+ joints_velocity: [-2.0632001283157804, 1.8756925130067745, 0.805089868421137, -10.093909143031876, -0.13722689253632356, -0.3694225195388197, -0.08259848244729512, 2.0587810704636382, -1.8682851301183323, 0.8114359142843033, -10.098099313560716, 0.13649340717046698, -0.3695213709803842, -0.08267746132911974, 0.0015523646727455423, 0.0007406817152263895, -0.34080222069152577, 0.0005379265015939548, 0.00038558488110489, -0.09266440557500383, 0.0006231207199858852, -0.0009119118818382365, 0.10762847229394334, 0.02433517821323348, -0.004458583808413399, 1.320363553378343, 0.003475778613850966, -0.007277286884135733, 0.8901162338646792, 0.0023521555360314665, -0.001693523103765735, 0.44378798289389176, -1.9519992752815816, 0.8644680489527223, 2.1165841848675937, -1.4324967510833246, -4.87899088812763, -1.3914685545080177, 0.24087885426900782, 0.03402899883030506, 0.026463374244691914, 0.08995391437978388, 0.0004854989677719594, 0.08658217892450819, 1.9540373421240074, -0.8643245087723799, -2.1219373047606362, 1.4366322547279746, -4.877933802182706, -1.3898263823374033, -0.24082794474968203, -0.03421922674807496, 0.02654391948320011, 0.08996791896283224, -0.0004978000813240749, 0.08659804434961563]
+ appendages: [-0.4104686716817022, 0.2307089061442757, 0.36224880402168297, 0.4125447479411606, 0.23100998026860592, 0.3624951052340618, -0.32796629532752286, -0.5249443788731726, 0.2943328237185734, 0.32857207412066847, -0.5245883661509362, 0.2941879088549444, -0.0012832017589331911, 0.32515617538162594, 0.3397861121212629]
+ body_positions: [-0.15569106800425475, -2.5349889254349086e-05, 0.7110340629819992, -0.09504116637214736, 0.10187645504598895, 0.6261678208599085, 0.3019583769323544, 0.17233246866592916, 0.5886476158256305, 0.20236130849763695, 0.3283432041265337, 0.22763808045396888, 0.2749025456350871, 0.33851151784178873, 0.1472135527891053, -0.15569106800425475, -2.5349889254349086e-05, 0.7110340629819992, -0.09509030775516279, -0.1019975330358268, 0.626217265173287, 0.30190712835894734, -0.1724090576090147, 0.5885913970621871, 0.20239169732285878, -0.3281952536328057, 0.2274624029850706, 0.27494923016980855, -0.33828409487957756, 0.14704256681997202, -0.15569106800425475, -2.5349889254349086e-05, 0.7110340629819992, -0.14806039671011584, 0.0006007611467025286, 0.8242759938038238, -0.10216631252563468, 0.0011491855974502464, 0.9281530662790645, -0.03027337418643694, 0.0010857955196282344, 1.0166786108073564, 0.053581390828521855, -0.0022775574274348457, 1.052299256802824, 0.14039303078864954, -0.001291301605434842, 1.076431642057573, -0.03027337418643694, 0.0010857955196282344, 1.0166786108073564, 0.041389550227492064, 0.18182876816368712, 1.0606838702354349, 0.17488915413520106, 0.41250637682869984, 0.9857990034993843, 0.33835404437209327, 0.3391469628575618, 0.9565937755829307, 0.42008612930877876, 0.3024674175135712, 0.9419912259760786, 0.4517454422311682, 0.28817530478700465, 0.936328257633121, 0.42035344827883403, 0.29930244918702026, 0.951473394296678, -0.03027337418643694, 0.0010857955196282344, 1.0166786108073564, 0.041246404207114395, -0.1797166927545293, 1.0606722613209298, 0.17448439537433957, -0.4105070249017406, 0.9856688817563909, 0.33801872059019056, -0.3373185832986863, 0.9564234412927459, 0.419785522863361, -0.30072452376201375, 0.9418007855009033, 0.45145842907621975, -0.2864656707611189, 0.9361299906280193, 0.42005988036426045, -0.29756254144092137, 0.9512837489492563]
+ body_quaternions: [0.5307017144256975, 0.46741438714490635, 0.4674303991663966, 0.5304604631173941, 0.6341146167766719, 0.19506669411522132, -0.10442218798143008, 0.7409073115405898, 0.2833288463735187, 0.5998975972409939, 0.4769148846151334, 0.5765412651924123, 0.6995587735511801, 0.35946827340971005, 0.07921893753281492, 0.6124740343031554, 0.7746669839713176, 0.1359815472237748, -0.105865905213412, 0.6084344606223053, 0.5307017144256975, 0.46741438714490635, 0.4674303991663966, 0.5304604631173941, 0.7409164304743725, -0.10411911457223594, 0.19489119182064765, 0.6342077549072824, 0.5764000965572306, 0.47703341146612294, 0.5998025934510628, 0.2836175275316868, 0.6123693446889762, 0.07932590902833683, 0.3591797779405756, 0.6997864481115513, 0.6083654698031525, -0.10573679259920639, 0.135633338553149, 0.7747998343539763, 0.46446000941060356, 0.5332859024859258, 0.533415468265005, 0.46405924630773093, 0.3885902308902296, 0.5908681243908487, 0.5908954209959785, 0.3882204186818412, 0.31706965962542377, 0.6321658218956656, 0.6319635306995746, 0.31694684165917103, 0.1808006334837045, 0.6835349119929177, 0.6832660024418032, 0.18231490593454813, 0.07779630680817932, 0.7028028646583752, 0.7027661348779418, 0.0783302479487189, 0.007206791641727911, 0.7070366743656784, 0.7071005144363859, 0.007487705920574433, 0.3502236413533386, 0.6516224763044389, 0.61068092997203, 0.2824895594187077, -0.10596373751857444, 0.6048667688227094, -0.2686074744907257, -0.7421306508646466, -0.47643935840370366, 0.38742439192544076, -0.6892441590951072, -0.38451315641775946, 0.6424638271518739, -0.27131812073444833, 0.5378589070418203, 0.4736396354595688, 0.64288597467787, -0.27079821834665135, 0.5382777081439324, 0.47288799669486237, 0.6975635669784355, -0.006254719942357406, 0.6775747757075795, 0.23293426503541279, 0.8900724155387778, -0.12773016287730457, 0.43671552544133063, 0.027122876447841053, 0.28260863253161683, 0.6109478984663969, 0.651369359804546, 0.3501328080684795, 
-0.7422521541911934, -0.26843235676279076, 0.6048859710371459, -0.10544558541615526, -0.38474846188404915, -0.6891754578019678, 0.387798340444582, -0.47604438510472413, 0.47396590177955555, 0.5378975122950438, -0.2716160078566914, 0.6420648989729075, 0.47321631684000076, 0.5383160712085013, -0.27109696496281444, 0.6424862337381078, 0.23322352558465925, 0.6777349497840833, -0.006683068873709895, 0.6973072934836241, 0.027574944641573812, 0.4368158403048195, -0.1280916351896783, 0.8899573456527246]
+ }
+ props {
+ position: [1.0, 5.9450419124621805e-22, 2.798895]
+ quaternion: [0.4984992508181226, 0.5004992494695555, 0.5004995771267048, 0.5004989225864246]
+ velocity: [-1.2287624621502384e-20, 7.119594265988146e-20, -1.9620000000000009]
+ angular_velocity: [0.010006383774561442, 0.01, 0.009993612249586828]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.14905274033694257, 1.76461139388136e-06, 0.6228408188884561]
+ quaternion: [0.5152097099740391, 0.48436770563899, 0.4844736018554525, 0.5150069995521342]
+ joints: [0.05463906663707841, 0.11504504385611591, -1.2586338549763343, 1.1633499301757235, -0.3657312654777581, 0.324666709199002, -0.606571199258106, -0.05415004338549242, -0.1146468546677924, -1.2583492948687767, 1.1630979133263877, 0.3657675463451495, 0.3246338832102315, -0.6065813898059011, -0.00011299664839562312, 2.9908340170470763e-05, 0.22081776128212155, 0.00010606235165363202, -3.3991150767127497e-06, 0.2275634700866528, 0.00032608455859424495, 4.1288105470084586e-05, 0.20846152928125558, 0.00178513096787077, 0.0014424157122995185, 0.4252014161208318, 5.539673924303481e-05, -0.0015416924321737238, 0.3100793838160297, 9.429944348505508e-05, -0.00037721036265731537, 0.2080092710527104, -0.01273389870356335, -0.03824908887455516, 0.17600426378033632, -0.119040771075688, 0.17581430584457902, 1.3502826316147345, 1.5612502452181398, 0.0024634695315846116, 0.001005315383512543, 0.7810995132713466, 0.0002470941243886318, 0.7808859661327122, 0.012774716559865291, 0.03812476982049563, -0.17636120726394577, 0.11924330252938306, 0.1754618747427401, 1.3503219748854949, -1.5612485126262257, -0.0024698295656321013, 0.0010080590578772002, 0.7811006992228876, -0.0002472101098887896, 0.780886744953391]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.447422110209054, 0.2810766213979492, 0.24599314979658615, 0.44943593063875603, 0.2814106698363178, 0.24624747657984128, -0.35384757498147024, -0.5587384923610282, 0.4705868576584394, 0.35419165528573204, -0.5584451846205485, 0.47061574031273046]
+ velocity: [0.2909304624544653, 0.0007251901352216487, -2.2128228032030166]
+ angular_velocity: [1.578546499406739, 0.003925769910054199, 0.0015821320828657815]
+ joints_velocity: [0.42724529261258587, -0.0022811842454657227, 0.34218840397472217, -7.965577521727297, 0.32848824943583144, -0.5072126883412631, -0.10626171615913542, -0.4485293956986247, -0.0030249878828598475, 0.3460632806031642, -7.958320904015658, -0.33004450447811035, -0.505971442201712, -0.10607490908149662, 0.005631261351970911, -0.0014483688426715391, -1.1516774616084005, -0.00029064676471282435, 0.0006798897678559478, -1.5405296304969902, -0.005793987472433036, 0.0003325850090176838, -1.2679893662202997, -0.003415717950547062, -0.0019122479757927291, -0.12030854221613563, 0.017816526128055028, -0.004663600328601208, -0.08618647170012385, 0.006147300090831453, -0.000751183441436881, -0.02393409502371011, 0.4532664277721973, 1.6611382954319396, -0.8980598810869039, 0.9247018455511251, -4.281273433645359, -1.1316721006028823, 0.026394858369526215, 0.07902406876576713, -0.041679787064090604, 0.00984161837119804, 0.003000784587813996, 0.016092376393339297, -0.45611241090835397, -1.6643916299403059, 0.9040867169285158, -0.9304903709673658, -4.285646047253225, -1.1303783535398029, -0.026384324669643487, -0.07901986066055805, -0.041556153644219315, 0.009875927183472494, -0.0029965112318921837, 0.016105603505734055]
+ appendages: [-0.447422110209054, 0.2810766213979492, 0.24599314979658615, 0.44943593063875603, 0.2814106698363178, 0.24624747657984128, -0.35384757498147024, -0.5587384923610282, 0.4705868576584394, 0.35419165528573204, -0.5584451846205485, 0.47061574031273046, -0.0014701333052544317, 0.35498838223985624, 0.312961360455793]
+ body_positions: [-0.14905274033694257, 1.76461139388136e-06, 0.6228408188884561, -0.09405688031581885, 0.10190585093599559, 0.5342085446320933, 0.2940088222518359, 0.19137792568575857, 0.4608586205990337, 0.3549728823657601, 0.3539783996098609, 0.09420990656461259, 0.45346624054333573, 0.35157331321028207, 0.04809044978469834, -0.14905274033694257, 1.76461139388136e-06, 0.6228408188884561, -0.09407853969284559, -0.10196813840026812, 0.5342708261578206, 0.29396323994405327, -0.1914191671310878, 0.46076882795207164, 0.35488682610401007, -0.35406088244707734, 0.09413167852842885, 0.4533717321400425, -0.3516407898069656, 0.04799496024410358, -0.14905274033694257, 1.76461139388136e-06, 0.6228408188884561, -0.13903005188004797, 0.0006142871657769529, 0.7358962238362066, -0.09530395280991744, 0.0011554011417327491, 0.840704378930698, -0.02751876938980012, 0.0011006878610017637, 0.9324133564112819, 0.05519851521077265, -0.0022877629709604818, 0.9705988694910105, 0.14154172151637126, -0.0013851347279333751, 0.9963601104999492, -0.02751876938980012, 0.0011006878610017637, 0.9324133564112819, 0.028808364896153946, 0.18526849858148545, 0.9838869534359334, 0.07951512902397279, 0.4495038151560815, 0.9186905716696976, 0.2575415503753189, 0.4481838206361939, 0.8831916954719862, 0.3465543687840667, 0.447523826284752, 0.8654423355920539, 0.3810499236360825, 0.4472627112688162, 0.8584685055281631, 0.3484041522177253, 0.4442662008076925, 0.874714130321219, -0.02751876938980012, 0.0011006878610017637, 0.9324133564112819, 0.02871589554685275, -0.1830937212800351, 0.9838928849031661, 0.07918649262173537, -0.4473542606123047, 0.9186155317726752, 0.25720424080768717, -0.44616509266549814, 0.8830685697871926, 0.34621272265284847, -0.44557051131233094, 0.865295167119328, 0.3807065268702235, -0.445334702951839, 0.858311782811521, 0.348067650469675, -0.44231674147305605, 0.8745672879461609]
+ body_quaternions: [0.5152097099740391, 0.48436770563899, 0.4844736018554525, 0.5150069995521342, 0.6218514947015582, 0.20998091054607004, -0.053162486282465046, 0.7525838729310828, 0.4042154126575698, 0.5171084648566386, 0.36906822936359546, 0.6580253625948638, 0.7438315688620233, 0.2390137032717638, -0.0876740966202629, 0.6179808246170744, 0.7812665990403906, 0.005954360694749555, -0.26823726645250545, 0.5635918875380181, 0.5152097099740391, 0.48436770563899, 0.4844736018554525, 0.5150069995521342, 0.7525142339053323, -0.05308072191770889, 0.21016854309929914, 0.6218793679019136, 0.6579687571039999, 0.3690153790080463, 0.5172296256401482, 0.4042007905617182, 0.6178733435161241, -0.08767913889938427, 0.23910896883658128, 0.7438896430223876, 0.5634864393095057, -0.26821284969260256, 0.006023952898818796, 0.7813505051931183, 0.4587283358808673, 0.5381542282154186, 0.5383020439269194, 0.45846401169332607, 0.39463546179508785, 0.5867832725251376, 0.5868471713366615, 0.39440923020718865, 0.3313584292279985, 0.6247365558165747, 0.6246183635646739, 0.3312970375717027, 0.1912806145014786, 0.6807418346087877, 0.6803054010039041, 0.1928389029587033, 0.08435353197156246, 0.7023594624492737, 0.7016787342190436, 0.0849860052988185, 0.01110323594787288, 0.7073836945374817, 0.7066480355076796, 0.011557714615178752, 0.34542070789315826, 0.6269285365108551, 0.6220914922976984, 0.3172496176164279, 0.005548541611422721, 0.635181078141149, -0.15079678708330807, -0.757479069455246, -0.39266172955722056, 0.49930206737374744, -0.5911449852316979, -0.49705313407406415, 0.5345158608741355, -0.35072962038210975, 0.4668280431780358, 0.6110262727274245, 0.5339387736221588, -0.34988600192053276, 0.4675671035272488, 0.6114492418689135, 0.6269336280578909, -0.12027075424740707, 0.665136047717995, 0.3874057429995266, 0.8636257142976262, -0.18756762974240548, 0.41297326255297373, 0.2200502085731429, 0.31734606524436926, 0.6222440768520763, 0.6267781134678168, 0.3453302479861108, 
-0.7575465080938225, -0.15062985885116767, 0.6351365089720573, 0.005962270383094831, -0.49719848220135526, -0.59106660842864, 0.49951814199780287, -0.3923207355116426, 0.6112237108091584, 0.4668417987301049, -0.35089924513020954, 0.5341666688124334, 0.6116472968978075, 0.46758312099447374, -0.35005356826434364, 0.5335879576078094, 0.38758239220456214, 0.6652264878177329, -0.12055888577380763, 0.6266730919862851, 0.2203621690245692, 0.4130485210441265, -0.18781064442977857, 0.8634573501713436]
+ }
+ props {
+ position: [1.0, 5.946848916082991e-21, 2.68730625]
+ quaternion: [0.4981238297230253, 0.500623827088952, 0.5006243367099742, 0.5006233189800253]
+ velocity: [-2.3852447794681095e-20, 1.2902005852577502e-19, -2.4524999999999992]
+ angular_velocity: [0.01000797908075115, 0.01, 0.00999201467512285]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.1473166894954605, 1.565848638863668e-05, 0.4936939802290103]
+ quaternion: [0.5248923817040687, 0.47379416043704575, 0.47409340592219545, 0.5246356103295676]
+ joints: [0.17357708905669528, 0.09017994009822164, -1.4183432988667217, 1.421497608241757, -0.3552122905589304, -0.06559259938402953, -0.5696516145630768, -0.17389816929064833, -0.08957688188577981, -1.4201412361889572, 1.4223069995880175, 0.35896440178980943, -0.06049495251882199, -0.5708757763077136, 0.00026673614928087916, -0.00039627010457463335, 0.25448492197023703, 0.00020258451391837358, 2.6288839100339257e-05, 0.18547722368030525, -4.4263531450641435e-06, 0.0003181376542556298, 0.13046411279112674, 0.0006376997022474405, 0.0019929175683550413, 0.3761103297948858, 0.001184803671953219, -0.001573591957676397, 0.27169109967437594, 0.0004237352274605827, -0.00034103009296100975, 0.18956907382786228, -0.0013966084879039418, 0.01788497192380461, -0.017910401434399807, 0.006286458380094605, 0.1332328823582221, 1.3372278599090979, 1.5621331936726135, 0.0034741146847604315, -0.0017382297460971877, 0.7810241518145326, 0.0002343405691583489, 0.7810589390320872, 0.0013138735516385544, -0.017885350701564093, 0.018061629473386083, -0.006357085842854059, 0.13331894891264845, 1.3371813807902466, -1.5621322979705428, -0.0034776027320356593, -0.0017379183407539582, 0.7810245908136529, -0.00023438215218742327, 0.7810592693471498]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.45805758741912234, 0.373048622607841, 0.21544541023825978, 0.45969937886746454, 0.37352506948698216, 0.2155205890771909, -0.3658743290635875, -0.4780807081055464, 0.4473240166802205, 0.3662770603152158, -0.47869134602278807, 0.4467935049936381]
+ velocity: [-0.09287950007436103, -0.0014783441804631311, -2.736852185479247]
+ angular_velocity: [-3.9979262712439025, 0.02624365630885249, -0.0256497667949595]
+ joints_velocity: [1.6131466752363068, -0.753555467015253, -2.442249298158421, 7.39386929107786, 0.5500734903235988, -2.1376984712757805, -1.5542347916221533, -1.6228143505531167, 0.7590749830472244, -2.4512142196803466, 7.374920306828262, -0.3973598573810468, -1.9402742495700396, -1.3837148830017676, 0.014769697356973255, -0.015018421499891577, 1.538697223546732, 0.012503679093366514, -0.007275079381424011, 0.769670343208217, 0.005078665329560932, 0.00074639690253833, -0.2678253547602788, -0.026437651214018587, 0.020013120861693674, -1.226783737699908, 0.018376651434839502, 0.002111081136070489, -1.110901147893416, 0.0047315304311376025, 0.0007841037115240885, -0.6014242803274932, 0.4184753785453915, 0.279441624705452, -4.3520552963158305, 2.8662738244707384, 1.2396611695492474, 0.14661053827631731, 0.022491822764987113, -0.04287437507262312, -0.06380639112971476, -0.00850361446589953, -0.003423047061447511, -0.004656320754843999, -0.42221707269012426, -0.2797490658592093, 4.3668239755263025, -2.8717151578838105, 1.254779420766303, 0.14343377887593733, -0.022537350132719073, 0.04299433018061929, -0.06396388681336174, -0.008557511680240359, 0.0034210327013149827, -0.004684181373730079]
+ appendages: [-0.45805758741912234, 0.373048622607841, 0.21544541023825978, 0.45969937886746454, 0.37352506948698216, 0.2155205890771909, -0.3658743290635875, -0.4780807081055464, 0.4473240166802205, 0.3662770603152158, -0.47869134602278807, 0.4467935049936381, -0.0015507215667952723, 0.3868169900481825, 0.2979605444633763]
+ body_positions: [-0.1473166894954605, 1.565848638863668e-05, 0.4936939802290103, -0.08880815535597676, 0.10190535496708765, 0.40732321750868977, 0.3113364523671971, 0.16384458583311506, 0.41264088756263734, 0.3457552915313331, 0.366059456239525, 0.06262474515471406, 0.4518896596570148, 0.349639998849856, 0.04531498891867318, -0.1473166894954605, 1.565848638863668e-05, 0.4936939802290103, -0.08880526802519786, -0.10196861333470543, 0.40743686847803345, 0.3114179304459619, -0.16334235258852148, 0.413362116196973, 0.34623144658524996, -0.3660914440421648, 0.06369422339331099, 0.4522400799806517, -0.349549816130728, 0.0457422170698121, -0.1473166894954605, 1.565848638863668e-05, 0.4936939802290103, -0.13804385604925684, 0.0006067391805102167, 0.6068134712072473, -0.09947358144749308, 0.0011105547956660549, 0.7136268318205493, -0.043821771206079704, 0.0010602754906187976, 0.8131669951930367, 0.0309451421416804, -0.0022132730593563774, 0.8652338438420033, 0.1098393519941963, -0.0013000402639918546, 0.9087594134397187, -0.043821771206079704, 0.0010602754906187976, 0.8131669951930367, -0.0027302156020754947, 0.18508640726084077, 0.8778657327141967, 0.02916794950320807, 0.4599368096286671, 0.8869106224404478, 0.2091115192167165, 0.48099613120888673, 0.8984056643864623, 0.299082907582263, 0.4915257455964724, 0.9041531600310676, 0.33395936051394526, 0.49572435279546834, 0.906309666430941, 0.2991898253650098, 0.4862687203386696, 0.9126591728471029, -0.043821771206079704, 0.0010602754906187976, 0.8131669951930367, -0.002744106273110704, -0.1829711264475961, 0.8778595627523006, 0.029154520260652295, -0.4578202828207072, 0.8869406152930952, 0.20909642218314686, -0.47888586355921453, 0.8984502867334441, 0.29906697665686127, -0.48941860751215244, 0.9042050970929816, 0.33394310630631524, -0.4936184918862973, 0.9063643430906139, 0.29917358281708967, -0.4841605533363562, 0.9127104778258467]
+ body_quaternions: [0.5248923817040687, 0.47379416043704575, 0.47409340592219545, 0.5246356103295676, 0.6287716559182714, 0.1860417282156919, -0.16783183379125308, 0.7361162650305854, 0.35515552432960107, 0.5512076500393047, 0.35304801004409053, 0.6673767921348639, 0.763646892631667, 0.1507580777354085, -0.21374582066586162, 0.5902780273052699, 0.7752415825807358, -0.06989322761321287, -0.3709965214585337, 0.5064355896233907, 0.5248923817040687, 0.47379416043704575, 0.47409340592219545, 0.5246356103295676, 0.7356495040985692, -0.1689939040117702, 0.1860114533196618, 0.629015585466924, 0.6676387306323325, 0.35213301156417864, 0.5514876130471275, 0.35513698791025944, 0.5888971240481273, -0.21263378658478205, 0.15338723025688883, 0.7644994491067879, 0.5051965090745045, -0.36985056053556575, -0.06808501296044142, 0.7767570283561254, 0.46052728151299394, 0.5367411996012352, 0.5366677082677986, 0.4602295932852078, 0.4087850988180358, 0.5771253429250396, 0.5769383333037225, 0.40861135704077944, 0.3702089324471548, 0.6024728013091102, 0.602419007002211, 0.3702204883870712, 0.25037798539586004, 0.6607526477568868, 0.6613179466967897, 0.25174466486524855, 0.15887429833150135, 0.6892078810373907, 0.6886850614044265, 0.15957550013709668, 0.09299951671173957, 0.7013307622709067, 0.7005161248641377, 0.0936130898391022, 0.3650617434861338, 0.5987195755783619, 0.6061283207000511, 0.3753308568549445, -0.05149298714456372, 0.7264368331157152, -0.030777567297426152, -0.6846099189458132, -0.49072303846923027, 0.538101198426268, -0.4485395682356084, -0.5181218539057724, 0.5914135692694055, -0.3828914931434093, 0.3036887137209445, 0.6414025722749723, 0.5899659839137372, -0.38287628378806027, 0.30379490028559264, 0.6426931987331185, 0.6912964019610345, -0.12947353203796982, 0.5255723456488898, 0.4786643903889914, 0.9103172130894346, -0.13348568924341406, 0.2642255437572739, 0.2892905189363599, 0.3753344731767828, 0.6061572081673589, 0.5986905798099226, 0.3650576144647839, 
-0.6845877800190568, -0.03073485014709691, 0.7264561478240453, -0.05154033103952445, -0.5181413848375048, -0.4484802837178663, 0.538098411075629, -0.49075965648461606, 0.641420854035001, 0.3036216208453524, -0.3828833800159702, 0.5914334423008198, 0.642712439471694, 0.3037285435380817, -0.3828676809455204, 0.5899847717322703, 0.47870732671463107, 0.52551841391254, -0.12945827331613072, 0.6913105289655191, 0.28932236470506517, 0.2641759887978171, -0.13344530410940045, 0.9103273955205975]
+ }
+ props {
+ position: [1.0, 1.2860444769298893e-20, 2.5511925]
+ quaternion: [0.49774831526158775, 0.5007483107097661, 0.5007490423869598, 0.5007475816457921]
+ velocity: [-2.3852447794681095e-20, 1.7527935121848989e-19, -2.9429999999999974]
+ angular_velocity: [0.010009574131879268, 0.01, 0.009990416845996015]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.14604111666103736, -0.00018877085154157645, 0.35643625979534055]
+ quaternion: [0.5891233859844324, 0.3901311107387944, 0.39165093019900443, 0.5893563450020767]
+ joints: [0.22158393153842632, 0.036603562664869534, -1.4522997639318025, 1.7462638304959954, -0.35681033231797116, 0.010037258058905665, -0.6165356092194828, -0.2222259265262575, -0.03579194798427226, -1.4520251957334307, 1.7471114343487215, 0.3592394434636958, 0.015718770367110717, -0.6147684946549435, 0.00025855348163278944, -0.0005656292457244406, 0.296250753532087, 0.000497265088503667, -0.00032457354260592843, 0.2407981621194745, 0.00046197195997951145, 7.905705387995505e-05, 0.16836139227633481, 0.0006128805010292822, 0.002895203825016563, 0.38951651507856716, 0.0015012395570405935, -0.0016339468936200027, 0.26519674859881864, 0.0005098687450545064, -0.0003340625142049851, 0.18027951737939493, 0.03912663843465701, -0.019733540025076206, -0.07020408723785494, 0.040330934875337475, 0.27690225023608106, 1.370150989406834, 1.5628841595242355, 0.001076869475313195, -0.003062303962548386, 0.7807864576963852, 9.881068339107794e-05, 0.7808696691727007, -0.03883380031239386, 0.019404283608662587, 0.07022106068429756, -0.04039589072360654, 0.27613873324634075, 1.3700848313825078, -1.5628845811543277, -0.0010750105212807702, -0.0030654609870958346, 0.7807860573871683, -9.860149775934591e-05, 0.7808694190885923]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.44196011124907825, 0.3575381252990847, 0.31487783735167996, 0.442695811281038, 0.3584499384187497, 0.3163225084383008, -0.34631837032228485, -0.45319316060485765, 0.33337171171132507, 0.34630581123153387, -0.4532349309600191, 0.3338118795363017]
+ velocity: [0.1390832947325073, -0.004323007709096276, -2.6794125414779377]
+ angular_velocity: [-6.663246541201264, 0.020843403638440713, -0.0042517126540332914]
+ joints_velocity: [0.676381651325779, -1.069310264252119, 0.3710051612895596, 5.3744950405148835, -0.45470785847315875, 2.200191886791215, 0.1846890089454169, -0.697976480460788, 1.0896486624026003, 0.4154099081674788, 5.382901718531191, 0.36583306544370975, 2.0904724533527212, 0.1719966533936872, -0.006284437787564234, 0.005129784599658711, 0.15581079130012332, -0.0023483333244512863, -0.0013153127165024437, 0.8329721432705096, 0.003781944860257529, -0.003798174730052998, 1.0088479047330707, 0.012440564357400378, 0.012881577227221818, 1.2305459559804621, -0.003402620756545851, -0.0032176011962546276, 0.6475118246547676, -0.0006738864898816224, -0.00013235079636635238, 0.21363015817896067, 0.5755792550988573, -1.3560130095784768, 1.4854921101898033, -0.9840904803410016, 3.927772205623189, 1.0093217936298355, 0.005408867996966984, -0.031050201923971215, 0.01647833142524033, 0.0011102911666826256, -0.000852031111216332, -0.0006144888487652363, -0.5656506967330284, 1.3504427777412384, -1.5036429178939714, 0.9918269629608354, 3.893956748737359, 1.0091091182057237, -0.005412690088334313, 0.031070074655799522, 0.01646110889091778, 0.0011300587590359825, 0.0008629757683079747, -0.0006075644737673376]
+ appendages: [-0.44196011124907825, 0.3575381252990847, 0.31487783735167996, 0.442695811281038, 0.3584499384187497, 0.3163225084383008, -0.34631837032228485, -0.45319316060485765, 0.33337171171132507, 0.34630581123153387, -0.4532349309600191, 0.3338118795363017, -0.0018563985296490032, 0.3465971333930139, 0.33407967310532827]
+ body_positions: [-0.14604111666103736, -0.00018877085154157645, 0.35643625979534055, -0.064938581234355, 0.10176167717991323, 0.2909141081020439, 0.3177962160053178, 0.14076807487837217, 0.4172985570880904, 0.3372232256967768, 0.3463608100741398, 0.06809800041874259, 0.4437500804016981, 0.34591477379064917, 0.04606256038069433, -0.14604111666103736, -0.00018877085154157645, 0.35643625979534055, -0.06464037370704956, -0.10211184083726259, 0.291242131837413, 0.3183000110578527, -0.13939744928213535, 0.4175231602780132, 0.3378145568116483, -0.3462625644085654, 0.06907974034191589, 0.44417304420576137, -0.34571093672173203, 0.046247784545456846, -0.14604111666103736, -0.00018877085154157645, 0.35643625979534055, -0.16570306235851745, 0.00048524332492200494, 0.468218671882228, -0.14918736439229488, 0.0010804703826858345, 0.5805747635905775, -0.11039191017455482, 0.001085691531352948, 0.6878141376103615, -0.04419418636473438, -0.0020747872736710834, 0.7504211413401091, 0.026991073503371155, -0.0010692975470996114, 0.8056592922787217, -0.11039191017455482, 0.001085691531352948, 0.6878141376103615, -0.07165823246260514, 0.18150966944800162, 0.763226007741923, 0.005372331929988802, 0.4434576519372174, 0.8089599477448124, 0.18010161569214084, 0.42253043172124644, 0.8535362659642171, 0.2674658725712695, 0.4120668677247104, 0.8758243268535814, 0.3013447308320544, 0.4081185800131239, 0.8845007188717066, 0.26481788223082864, 0.4062552458051301, 0.8835193380829939, -0.11039191017455482, 0.001085691531352948, 0.6878141376103615, -0.07144836837728727, -0.17928319677824078, 0.7632497024014187, 0.005689855249315062, -0.44119986284347984, 0.8089815418789907, 0.1804044128712686, -0.42021047976833126, 0.8535863448612901, 0.26776130671274595, -0.40971583447917725, 0.875888648069338, 0.30163732170328894, -0.4057555476477195, 0.8845706713319819, 0.2651099378451256, -0.40390573630375926, 0.8835836465956418]
+ body_quaternions: [0.5891233859844324, 0.3901311107387944, 0.39165093019900443, 0.5893563450020767, 0.607799185859568, 0.09034520575714904, -0.2870844078251171, 0.7348472196646366, 0.32123260070946885, 0.5238241334420229, 0.3787147779348162, 0.692093209357618, 0.7215157049881905, 0.1839272177516369, -0.1894431255384009, 0.6400759081638505, 0.7433089509523788, -0.04365704305765328, -0.3747173874045041, 0.552424425240484, 0.5891233859844324, 0.3901311107387944, 0.39165093019900443, 0.5893563450020767, 0.7332751846287254, -0.28849595309598586, 0.09135361021740869, 0.608877743520974, 0.6920053505003827, 0.3768965457839796, 0.5254345672281167, 0.320930061254394, 0.6382481473937565, -0.18838574134084562, 0.18777078780134185, 0.7224211002291955, 0.551333132447433, -0.37266830058082967, -0.03961079709090301, 0.7453731277432271, 0.5250970346308781, 0.47302436292889805, 0.4741221731842253, 0.5250992488974214, 0.4644038783611685, 0.5328637902109189, 0.5335638333247427, 0.4644296013530633, 0.41781875319308703, 0.5701182537813069, 0.570630692456461, 0.41805894203113003, 0.2987321662843437, 0.6395923984149322, 0.6413549680606809, 0.3005735544933735, 0.21174827857099257, 0.6742330479510498, 0.6747063298267985, 0.21294091214351102, 0.15023513098244604, 0.6907688360209548, 0.690923875089999, 0.15130108866614966, 0.4150598403216197, 0.5853482664878671, 0.5553006014783564, 0.4203974046894394, -0.15276882248838003, 0.7777702373693378, -0.059692943693782204, -0.6067717010631993, -0.6104194977501374, 0.5056212930924809, -0.43014805424680586, -0.43209697541830205, 0.7019636302761029, -0.37438603314846264, 0.25472803845560016, 0.5497233725952676, 0.7010936842677277, -0.3753227449834999, 0.2540869968386778, 0.5504909454436616, 0.7911728897809431, -0.08028060474184923, 0.44445983889328994, 0.4123783876355615, 0.9627553418744064, -0.0838802567980898, 0.17574802568835707, 0.18756035209613553, 0.4201942460065122, 0.5549287618617761, 0.5856966472694485, 0.415271359786468, -0.6066789474022118, 
-0.05986722622612885, 0.7778062318014046, -0.1528856951249325, -0.43192910676274604, -0.4302100379212489, 0.5055954126958325, -0.6105160511107507, 0.5495545209589006, 0.25476263557334217, -0.37440526585715417, 0.7020730196409712, 0.5503219542334524, 0.25412064714945476, -0.37534347129310686, 0.7012030532867474, 0.4122093946687013, 0.44442656487416077, -0.08025830936361361, 0.791281901195213, 0.18736841447902527, 0.1756939920171993, -0.08389614400555817, 0.9628011920670203]
+ }
+ props {
+ position: [1.0, 2.185570879368697e-20, 2.39055375]
+ quaternion: [0.49737270751059937, 0.5008727002823389, 0.5008736941442018, 0.5008717105704764]
+ velocity: [2.2406844898033758e-20, 1.2902005852577502e-19, -3.4334999999999956]
+ angular_velocity: [0.010011168927905148, 0.01, 0.009988818762247047]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.13772128034548192, -0.0003153326458560554, 0.23124390693790844]
+ quaternion: [0.6451675365771709, 0.28789543479806806, 0.28973571823620603, 0.6456998388938362]
+ joints: [0.24439047259200408, 0.0004897356451942789, -1.4015317053172756, 1.9447900235198092, -0.39034381913724536, 0.0864605019573331, -0.571819741120574, -0.24570324527924742, 0.0013626110719718422, -1.3997716274546983, 1.9458506694400235, 0.3902582946020553, 0.08796798843646911, -0.571505674104913, 0.00010206280532636057, -0.0002891977570821109, 0.2835840679358296, 0.000366645035108718, -0.0002132944478806393, 0.26012006606748705, 0.0004550203001227616, 5.49917132883616e-05, 0.20538933548773392, 0.0008297181094717135, 0.003077015752089414, 0.43258295087542864, 0.0009558830560233747, -0.0016594613525204498, 0.29673986828208576, 0.000337450050740188, -0.000336697498213008, 0.1954689000179158, 0.029805607309010718, -0.08337908032710968, 0.06374312820213646, -0.05048373776254516, 0.42058172252017745, 1.4139460862750821, 1.5628077354109948, 0.0005802595943222307, -0.0005057010780621528, 0.7811475500333438, 0.00012837146368164907, 0.7810117474640534, -0.02930918029401539, 0.0830571362174467, -0.06428449101810291, 0.05074510806011148, 0.4192263563169192, 1.4136422648716203, -1.562806461764115, -0.0005880532360184639, -0.0005144402484424855, 0.7811469629109008, -0.00012834263541613755, 0.7810119084824778]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.42016835049666734, 0.2744801442650083, 0.363538761825885, 0.42122843489268996, 0.2753047485519404, 0.36479841964018667, -0.3345979390715314, -0.43619240313520247, 0.23779032887300106, 0.3343740711720341, -0.4362113385225768, 0.23888992442797874]
+ velocity: [0.13498524337002965, -0.0012259080620148327, -2.3568552764322415]
+ angular_velocity: [-6.384043914050266, 0.0062719609621664524, -0.006865946956795592]
+ joints_velocity: [0.25030579264605823, -0.3695757910107399, 1.3448625464386383, 2.8926359844822964, -0.6935476080276252, 1.1785758953670553, 1.2196848625129495, -0.25541371696534576, 0.38056798880508874, 1.3658722147328726, 2.894228390748524, 0.671646474291525, 1.1320301003950117, 1.1941867559002064, -0.0006927367105648418, 0.004777927718693852, -0.4358812599850629, -0.0008859327750862912, 0.003273767894602094, 0.07945102207427282, -0.00036652763429039436, 0.0013936856326338517, 0.4333331225221482, -0.004065125606943976, -0.002618931498428795, 0.21501748819531724, -0.012176098530256566, 0.0021566369037275173, 0.31382960681138294, -0.00446496097383663, 0.00013030059258597632, 0.21051670311123244, -0.6631761489189109, -0.8892174113196271, 2.7674934332151255, -2.0000129754281946, 1.2649466194323424, 0.5749629984734863, -0.004078612596504411, 0.0012441246394157305, 0.06246766145281057, 0.010115833045224868, 0.0009595546382477575, 0.004421977113756022, 0.6642481255419408, 0.8929480706828985, -2.770903984810195, 2.002245319930657, 1.273617131356737, 0.5710417483137555, 0.0041386092850853986, -0.0015398986937947715, 0.06231985266774661, 0.0100888126413032, -0.0009759768271846756, 0.0044271891723452516]
+ appendages: [-0.42016835049666734, 0.2744801442650083, 0.363538761825885, 0.42122843489268996, 0.2753047485519404, 0.36479841964018667, -0.3345979390715314, -0.43619240313520247, 0.23779032887300106, 0.3343740711720341, -0.4362113385225768, 0.23888992442797874, -0.0015747499645076533, 0.3251035651305464, 0.34148148251194604]
+ body_positions: [-0.13772128034548192, -0.0003153326458560554, 0.23124390693790844, -0.03985536126241482, 0.10171945097037322, 0.19552511239797712, 0.2953174870721631, 0.1410794293827344, 0.41933397474188605, 0.3304863349263793, 0.33453450148769714, 0.06447471043210057, 0.4368646966695563, 0.3449615630954711, 0.04425959533409822, -0.13772128034548192, -0.0003153326458560554, 0.23124390693790844, -0.03949855388493839, -0.102153800814312, 0.19594674165895692, 0.29634805279798476, -0.13927035512721495, 0.4191279935608915, 0.33082460289350657, -0.33443799677345376, 0.06513961889758718, 0.4371264439795752, -0.3447785338340579, 0.04448236316574944, -0.13772128034548192, -0.0003153326458560554, 0.23124390693790844, -0.19379422251257739, 0.00033822711169889235, 0.3299239559450523, -0.2138288855356832, 0.0009732084783048382, 0.4417059893616595, -0.20689319563459674, 0.0010501812136663455, 0.555535954292561, -0.15771584094999966, -0.002013144673587307, 0.6322429813541589, -0.0997249762113982, -0.0008516718677312957, 0.7012017816554068, -0.20689319563459674, 0.0010501812136663455, 0.555535954292561, -0.17960194543225055, 0.18104415272331867, 0.6367604434302306, -0.04989575851802261, 0.4219938034460922, 0.6787320712488507, 0.10175320456062883, 0.35940281783701045, 0.7564513083375763, 0.17757735195364072, 0.32810746294668075, 0.7953107556338405, 0.20699002563220004, 0.315996345763991, 0.8103726698224734, 0.17213938307060003, 0.32400765775052165, 0.802633329756417, -0.20689319563459674, 0.0010501812136663455, 0.555535954292561, -0.17904852166038873, -0.1787458730622758, 0.6370105684442865, -0.048812987231085375, -0.41940358368076913, 0.6790179313796014, 0.10275232585922447, -0.3565132107373771, 0.7566586175877836, 0.1785346484423816, -0.3250681628395682, 0.7954787896168567, 0.20793131932110773, -0.31289935346606296, 0.8105254374508145, 0.17309223290013454, -0.32097670668657374, 0.8028027297868493]
+ body_quaternions: [0.6451675365771709, 0.28789543479806806, 0.28973571823620603, 0.6456998388938362, 0.5592601969007844, -0.0010200349301371923, -0.37232883657776844, 0.7406741720513026, 0.315886531694501, 0.4615069960496612, 0.40223145069251787, 0.7248702309827807, 0.6869006771544904, 0.1751079064204591, -0.1744670150044856, 0.6834222278424743, 0.7084020162905641, -0.02572746551289576, -0.36013051750769076, 0.6064739822889265, 0.6451675365771709, 0.28789543479806806, 0.28973571823620603, 0.6456998388938362, 0.7388020569270336, -0.3736575365953103, 0.0008720151883868099, 0.5608482910871888, 0.7247012509731243, 0.400320535086673, 0.46405207466916387, 0.3149718051209561, 0.6819510115883214, -0.17476330852910207, 0.17828911217901017, 0.6874689784012421, 0.6050354501823039, -0.35990484886573815, -0.022724208048967798, 0.7098480218710019, 0.5980017842455658, 0.37628793850612513, 0.3779685165404859, 0.598281751249455, 0.5440628974288245, 0.4507898720701073, 0.45224696897868544, 0.544294712384716, 0.4948395596498306, 0.5042625809747993, 0.5055890947495151, 0.4952097807036547, 0.37424855274815083, 0.5979590105301641, 0.6008341676209316, 0.37600710832504736, 0.28196918892956385, 0.6473526910199894, 0.6491798589395243, 0.28283101080122314, 0.2175015037605528, 0.6719400314459697, 0.6735182325407914, 0.21808915708355653, 0.507737511984677, 0.5322404937244727, 0.4772717846369875, 0.4807643074864905, -0.19387775227359508, 0.8231005945344175, -0.16503817418774522, -0.507621147620154, -0.6820442150502575, 0.499898850034736, -0.4552015517451285, -0.2787622207890695, 0.7776638608413446, -0.4091207327922798, 0.26006092273932363, 0.4002842265251198, 0.7774442415614229, -0.4092418544272058, 0.2600783324643262, 0.40057560688407023, 0.8747016773346191, -0.08243433641413095, 0.3929987935988911, 0.2713917907339535, 0.9877464056462806, -0.10989374943085506, 0.10504570804572184, 0.03529590897051716, 0.48031551719462445, 0.47611498451789724, 0.5333252120809759, 0.5081099719729628, 
-0.5076086409365561, -0.1658605039690307, 0.8227879552688983, -0.19453467951526593, -0.2782878198996944, -0.45577640178687795, 0.49933809223819414, -0.6822647803526153, 0.39967944775427, 0.26055715411344627, -0.40870402506880904, 0.7780278455014948, 0.39997522959124615, 0.26057443965415605, -0.4088274773273086, 0.7778051625532846, 0.27064786537837077, 0.39322888481618284, -0.08191401307243595, 0.8748776323433589, 0.03442042972290656, 0.105352448818611, -0.10967569105158792, 0.987768868884806]
+ }
+ props {
+ position: [1.0, 2.738152586612142e-20, 2.2053900000000004]
+ quaternion: [0.49699700654686785, 0.5009969957570299, 0.5009982919682395, 0.5009957057408674]
+ velocity: [1.6118472297617834e-19, 5.963111948670272e-20, -3.9239999999999937]
+ angular_velocity: [0.01001276346878814, 0.01, 0.009987220423916678]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.13815227342749042, -0.0007116269648282258, 0.1213042913458387]
+ quaternion: [0.6805025419337923, 0.18979015782416292, 0.19495501487570271, 0.6803591173707757]
+ joints: [0.23794748648832445, 0.006564998334316186, -1.3282968861288391, 1.9923297515278588, -0.3921009064262061, 0.16826823880946049, -0.5318633523585797, -0.23962373320435265, -0.005734564698284298, -1.3286877635082182, 2.003642228006561, 0.40475344741498315, 0.1772994508528556, -0.5543390005365412, 0.0010444169183577632, -0.0010625683910214883, 0.25471227868464535, 0.0014045055126212732, -0.0009629525421880857, 0.24346260695468985, 0.0013472666325416128, -0.000510792787764063, 0.20175670166996296, 0.0007935077771278979, 0.0024777063103889265, 0.4051995663690082, 0.0010047032517330001, -0.001573149409139616, 0.28616166692570405, 0.0003105039882299102, -0.00032322260334212096, 0.1938277320912592, -0.003655123731294004, -0.09033227652083493, 0.1335692308116171, -0.10322736143595397, 0.38909429650211463, 1.4114292733512765, 1.5625776222930527, 0.0011361073948640281, 0.0015584526265023696, 0.7815212104908785, 0.00017600706350556135, 0.7811959076712557, 0.004772182299202081, 0.089841775187075, -0.13463222331115407, 0.1037582329371469, 0.387566377622479, 1.4113076873507147, -1.5625755300419388, -0.0011460394121054131, 0.0015526317666673065, 0.7815216308679511, -0.00017592259402619733, 0.7811966233323636]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.42436030453276846, 0.25621783906835965, 0.3316215113595368, 0.4231029317078294, 0.2593116342111137, 0.3354361695080158, -0.3323876564804031, -0.4355280545694308, 0.18907637987368886, 0.33353460715291444, -0.43825447744794127, 0.1925119595163956]
+ velocity: [-0.17740403255855033, -0.007391087979481665, -2.0844311312160535]
+ angular_velocity: [-5.060843897088759, 0.07248336513132983, -0.0688936160039453]
+ joints_velocity: [-0.5800953530031085, 0.5896243187494642, 1.6685412270315534, -1.4643108929802733, 0.6840540048440907, 1.8323627234465274, 0.4039191979063441, 0.9936514592536638, -0.6608951471497937, 1.8101461368229597, -1.688393039727281, -0.9461376985248786, 4.1452545718134575, -0.47485502246232275, 0.029770744612818145, 0.004826410390125581, -0.8138895101390332, 0.010171567139190762, -0.011937337510032186, -0.757135250284933, 0.012544706162523168, -0.016962817622463412, -0.5242259458675047, 0.005037479046053019, -0.020580205332692333, -1.0376259880619594, 0.014217899689613015, -0.0004982321238264668, -0.5675864516680231, 0.004843578953725682, -0.0003969782161212659, -0.2067492474434278, -0.4826473527005624, 0.40579293537034655, 0.014384090985740283, 0.03345251289337037, -1.6871389438538733, -0.5462702874829494, -0.004804020322923741, 0.018782060775074872, 0.01509580678397228, 0.0030747399423543533, 0.0011053020311475189, 0.0022254227391247158, 0.49415221860424224, -0.42129204360700345, -0.042644826548624504, -0.018181975099201164, -1.7326196049234908, -0.5376703889921307, 0.004774017733362191, -0.01856470883255415, 0.015386329467386784, 0.0031792928080699202, -0.001072981840678523, 0.0022559229872237256]
+ appendages: [-0.42436030453276846, 0.25621783906835965, 0.3316215113595368, 0.4231029317078294, 0.2593116342111137, 0.3354361695080158, -0.3323876564804031, -0.4355280545694308, 0.18907637987368886, 0.33353460715291444, -0.43825447744794127, 0.1925119595163956, -0.0030226168788018985, 0.34659043224882796, 0.32618083766315187]
+ body_positions: [-0.13815227342749042, -0.0007116269648282258, 0.1213042913458387, -0.03422841667403688, 0.10136710830212135, 0.11463861262888245, 0.23924555895190935, 0.15994037375525255, 0.40748944539960696, 0.33542671276040914, 0.3332001351912871, 0.053487960984349714, 0.442431097396616, 0.35088998762189805, 0.04507005572413594, -0.13815227342749042, -0.0007116269648282258, 0.1213042913458387, -0.033863080005404986, -0.1025014490125453, 0.11608282298579523, 0.2397975285681957, -0.15571570807541057, 0.40978091111802695, 0.3324983734239777, -0.33272235872673883, 0.056705835196156684, 0.43939744636451034, -0.3498184068316801, 0.046026711495764996, -0.13815227342749042, -0.0007116269648282258, 0.1213042913458387, -0.2223686900393193, 0.00010375729566793393, 0.19739163837051854, -0.2781367100330855, 0.0008731812312696299, 0.2963176200503535, -0.3095148276642352, 0.001026087298972174, 0.4059568511176885, -0.2911318697272233, -0.001951189842342993, 0.4952035141264797, -0.262648279164984, -0.0007031731866003461, 0.5806829448353152, -0.3095148276642352, 0.001026087298972174, 0.4059568511176885, -0.30797747664151004, 0.18355972845658833, 0.4860765500105048, -0.18420690889315672, 0.4252850696738375, 0.5398515685753987, -0.06216892768244274, 0.3686755512526964, 0.6617425561919569, -0.0011502059779786161, 0.3403709167766544, 0.7226877814232316, 0.022540756005114303, 0.32936011861445974, 0.74627053858831, -0.008537472054112306, 0.3371092178824157, 0.728586054072383, -0.3095148276642352, 0.001026087298972174, 0.4059568511176885, -0.30703821444028967, -0.18118792866098693, 0.4867775076838336, -0.18204975408887641, -0.42218889553424577, 0.5409851214861177, -0.06017527366698108, -0.3647355991316112, 0.6626446835861189, 0.0007616980033345946, -0.33600907752401943, 0.7234741965690429, 0.02442100071811808, -0.3248349470322675, 0.7470118683775192, -0.006635744754817632, -0.33276979836432874, 0.7293720700133832]
+ body_quaternions: [0.6805025419337923, 0.18979015782416292, 0.19495501487570271, 0.6803591173707757, 0.5026327812182236, -0.0946433490017013, -0.4183750565242658, 0.750576602228509, 0.35263665270517924, 0.3704652040803714, 0.4026310027093858, 0.7591384586436194, 0.6540343786189661, 0.1168463522480344, -0.18459627196523967, 0.7242307490899339, 0.6617519983787098, -0.059146691293366735, -0.36844121406084474, 0.650259204727009, 0.6805025419337923, 0.18979015782416292, 0.19495501487570271, 0.6803591173707757, 0.7465887960552225, -0.4241402467947797, -0.09116369060997703, 0.5043802158769166, 0.7595573490043931, 0.4004527599102574, 0.3758007288025455, 0.34854559656543127, 0.7194899694128202, -0.18188454252641997, 0.1286600043628237, 0.6577984496649809, 0.6422599345038468, -0.3718198211860423, -0.056246661391145425, 0.6678985777695923, 0.6505768360485061, 0.275112816003566, 0.27936548597917404, 0.6503980658883628, 0.6119732629315923, 0.352543585791901, 0.35581334285477606, 0.6120446148936846, 0.5729984344505091, 0.4127430974784015, 0.41528246690405546, 0.5734600267527656, 0.4776108236449396, 0.5189134905491583, 0.5228611744583107, 0.47878270928195277, 0.3988142509936698, 0.5823719196263009, 0.5851354508547201, 0.39925761681924044, 0.3406085880534619, 0.6183870950284507, 0.6208655116830231, 0.34074801077787537, 0.5922444447448203, 0.4374081514843328, 0.38969320142426944, 0.5532267486733287, -0.18245204199083184, 0.8912257058028998, -0.17359860813782602, -0.3772154781316749, -0.7169023841633292, 0.5600205155987019, -0.3767879141166701, -0.17452467294084076, 0.7907208216024728, -0.49447547967263816, 0.17515348391558463, 0.3155563964793895, 0.7909264317251293, -0.49375974259273003, 0.17568053079338225, 0.3158687184471916, 0.9193751310251133, -0.1552825561975336, 0.2827494688506264, 0.22514314126351784, 0.986740962942167, -0.15974555756387085, -0.028698462529492638, 0.00016473923815092312, 0.5524475274995899, 0.3875024670059333, 0.439547935274099, 0.5928247464640642, 
-0.37735206975864527, -0.17541686211417254, 0.8904335543115068, -0.18428897224484306, -0.17347235526894592, -0.3782499200562516, 0.5582698138200173, -0.7177528508562311, 0.31408151932570794, 0.17633978723099047, -0.4930625135652143, 0.7919257770612956, 0.3143980473548977, 0.1768664182704458, -0.4923465932240429, 0.7921281272927317, 0.2233315755642781, 0.28328584406631735, -0.153518021221377, 0.9199480175900614, -0.002139868005326501, -0.02787780747309255, -0.15874897560039158, 0.9869230018401838]
+ }
+ props {
+ position: [1.0, 2.81657654375526e-20, 1.9957012500000009]
+ quaternion: [0.49662121244722063, 0.5011211970842195, 0.501122835845613, 0.501119567143793]
+ velocity: [1.8431436932253577e-19, -7.914675859144182e-20, -4.414499999999996]
+ angular_velocity: [0.0100143577544876, 0.01, 0.009985621831045642]
+ }
+}
diff --git a/dm_control/locomotion/mocap/test_002.textproto b/dm_control/locomotion/mocap/test_002.textproto
new file mode 100644
index 00000000..ade8797b
--- /dev/null
+++ b/dm_control/locomotion/mocap/test_002.textproto
@@ -0,0 +1,247 @@
+identifier: "cmuv2019_002"
+year: 2020
+month: 7
+day: 7
+dt: 0.05
+walkers {
+ name: "cmuv2019_CMU"
+ model: CMU_2019
+ markers {
+ marker {
+ name: "_left_shoulder"
+ parent: "lhumerus"
+ }
+ marker {
+ name: "_left_elbow"
+ parent: "lradius"
+ }
+ marker {
+ name: "_left_wrist"
+ parent: "lhand"
+ }
+ marker {
+ name: "_left_hip"
+ parent: "lfemur"
+ }
+ marker {
+ name: "_left_knee"
+ parent: "ltibia"
+ }
+ }
+}
+props {
+ name: "cmuv2019_box"
+ shape: BOX
+ size: [0.1775, 0.1275, 0.1775]
+ mass: 3.0
+}
+timesteps {
+ walkers {
+ position: [-0.16228434580157095, -0.0013143659046522152, 0.10500312628961993]
+ quaternion: [0.6973675601678487, 0.11586869848078254, 0.1177806328562669, 0.6974099606844518]
+ joints: [0.16547813534236716, 0.0011641685543727978, -1.043304902546976, 1.773013665955972, -0.37886452015032374, 0.27983384483048845, -0.5748501909854249, -0.16161441838887283, 0.012750300746476017, -1.0532878440531264, 1.7889617256476318, 0.37925256429779786, 0.2970156126979916, -0.5877945825725448, -0.0016635526257656676, 0.001592170709696942, 0.10579869279791383, -0.00030279763464064174, 0.0012202164590354697, 0.15645923709192186, 0.0012555895290633072, 0.0006286985389313479, 0.1883310627400529, 0.0028223250872085655, 0.001908364300817418, 0.4625037442892667, -0.0002893878079503424, -0.0014208292426395794, 0.32489545587589364, 1.3890662821862233e-05, -0.0002397016951764242, 0.21073283309433563, -0.008070627282413298, -0.04476941657369893, 0.19267002348274448, -0.1378345004196423, 0.22950765773332374, 1.378866040988104, 1.5620585854306108, 0.0029175814628601532, 0.0016017458291099024, 0.7815915981859063, 0.0002609321646593133, 0.7813267030913068, 0.004843731971775013, 0.045500602687878706, -0.20071473737605092, 0.1418248898687894, 0.22559138836667023, 1.3781649547822672, -1.5620444471899602, -0.0029720021827282566, 0.001606245674907756, 0.7816000541227026, -0.0002611649893273479, 0.78133306807215]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.43977311233940747, 0.31936023857302803, 0.21542587440169547, 0.44456814093464936, 0.320134412981704, 0.21267836845950522, -0.3459760513931765, -0.5294742553291328, 0.1291036227785771, 0.3511506008607633, -0.5324413390553756, 0.12889160722596044]
+ velocity: [-0.5502981354157529, 0.00021204209834730205, 0.10333740303813647]
+ angular_velocity: [-3.6928857687430128, -0.0642138502844129, -0.04197332530172873]
+ joints_velocity: [-1.4232956684013167, 0.3924022147122244, 5.988093556539502, -5.625777085989636, 0.1944374093412805, 2.2077298705466144, -0.7952736214833999, 1.8539789396528723, -0.31509507655958646, 6.335498870246781, -6.190466771524876, -0.4632926267445042, 2.2855843610393163, -0.46198106882151624, 0.1191296373416708, -0.008945633642468023, -0.6596895959570053, -0.04436207398393202, 0.03707275009031557, -3.9286770620379006, -0.14884956479441858, 0.04462407296769153, -3.944537496793055, -0.013532888592839595, -0.018292238477423863, 0.5634235635931463, 0.03445979254006389, -0.01077681501151364, 0.8210295210466338, 0.008522337923295006, 0.001975509252360874, 0.5308785820008446, 0.5253377011251841, 1.722822544488206, -0.8924746332150322, -0.16709579280684328, -2.3228797915913377, -0.8321106426882738, -0.012884224903582214, 0.06712792286309488, -0.011357205619738906, -0.001338827021690576, 0.0035527380749772854, 0.0024980346859271876, -0.4685953777059846, -1.6945636707464253, 0.8624653897125703, 0.22052738549460527, -2.231887031165224, -0.836314935146574, 0.013654329097135091, -0.07011990636889895, -0.010700807079311184, -0.0010791837898023278, -0.003683501675912779, 0.0027409109587061024]
+ appendages: [-0.43977311233940747, 0.31936023857302803, 0.21542587440169547, 0.44456814093464936, 0.320134412981704, 0.21267836845950522, -0.3459760513931765, -0.5294742553291328, 0.1291036227785771, 0.3511506008607633, -0.5324413390553756, 0.12889160722596044, -0.0002274346099897847, 0.40388632554465265, 0.2498608013262344]
+ body_positions: [-0.16228434580157095, -0.0013143659046522152, 0.10500312628961993, -0.05942653454856292, 0.10071902721537468, 0.12155984159262079, 0.22211796830217947, 0.2022564248242864, 0.3943295298395892, 0.38291664109663887, 0.3499735806724411, 0.05240813690740692, 0.490285892869232, 0.3672137609271936, 0.04951844869980155, -0.16228434580157095, -0.0013143659046522152, 0.10500312628961993, -0.05932340254604104, -0.10315422715810854, 0.12210149734590228, 0.22270650942854034, -0.19628284355122952, 0.3973589088452556, 0.38053327944171056, -0.3471488045084712, 0.05542765443778053, 0.48781665000592483, -0.36421335639648655, 0.04969966974831458, -0.16228434580157095, -0.0013143659046522152, 0.10500312628961993, -0.2682355182480033, -0.000498998957019014, 0.14569757664406638, -0.3618055634590654, 0.00042416833785132986, 0.21004668945754892, -0.4393064578521823, 0.0007796202494424051, 0.29370585307745223, -0.4581219489485899, -0.0021489169881587144, 0.38286394381712596, -0.46269476400421505, -0.0007164753473162552, 0.47284529202112613, -0.4393064578521823, 0.0007796202494424051, 0.29370585307745223, -0.4793782862830066, 0.18482594797643756, 0.3589841616455982, -0.39585737850841063, 0.4439454890282213, 0.40921729870292733, -0.32802760546101434, 0.4266091720384251, 0.5767102636225562, -0.2941128683947749, 0.41794105174270946, 0.6604563770250709, -0.28085687525051756, 0.4145584082587185, 0.6928829426800954, -0.3030253314779195, 0.4148528133312928, 0.6637776396322485, -0.4393064578521823, 0.0007796202494424051, 0.29370585307745223, -0.4793048036786062, -0.182330187757744, 0.3616101261081027, -0.39378273237847194, -0.440385475952749, 0.4139120400704776, -0.3252073714033521, -0.42110069321457005, 0.5808875961194792, -0.290919842016094, -0.41145834433795087, 0.664375006226749, -0.277517931983652, -0.40769811764087854, 0.6966998464278846, -0.29983475187256026, -0.40839758956391464, 0.6677150904269294]
+ body_quaternions: [0.6973675601678487, 0.11586869848078254, 0.1177806328562669, 0.6974099606844518, 0.5013492069769309, -0.12491544492858264, -0.35110828246614695, 0.7808764807993446, 0.41370720258512345, 0.30951776402541886, 0.3831329973497297, 0.7656723911854918, 0.6517958141452932, 0.07016533660368522, -0.17896666866362934, 0.7336279531952207, 0.6449479357858632, -0.11748667125928823, -0.379596627093152, 0.6527981639844498, 0.6973675601678487, 0.11586869848078254, 0.1177806328562669, 0.6974099606844518, 0.7746021374510125, -0.3587733533594273, -0.12239632448197746, 0.5062532462403149, 0.7646435744435849, 0.37953523484201496, 0.3182081626553732, 0.41233090448697385, 0.7309370513807132, -0.17373752951572308, 0.08374855562997573, 0.6546239204203698, 0.6492669641740867, -0.37802924685204387, -0.10947686959774057, 0.6508157287433374, 0.690785153960364, 0.1519535207126444, 0.15511639425374601, 0.689684640139756, 0.6768416444450202, 0.20503142434580665, 0.20898293440249538, 0.6754062751442372, 0.654084639021862, 0.2676370807384784, 0.2716926108319736, 0.6532432955865154, 0.5742367119058869, 0.409920798644667, 0.4146897169161058, 0.574673451815963, 0.5006946838338886, 0.4978022201663987, 0.501769642566828, 0.4997249333145629, 0.4456129247368263, 0.5477692785273317, 0.551474698987181, 0.4441324072685263, 0.6626561987983823, 0.28103400711272203, 0.25800290513992274, 0.644469665678382, -0.11912122136932858, 0.9666959051558296, -0.13349187464587686, -0.18300022125112356, -0.706829814987572, 0.6701361436829117, -0.21940964149426403, -0.05629005945803593, 0.7396521677924552, -0.6317315813809083, 0.02576138225966268, 0.23057803657323928, 0.7398206847191842, -0.6311010288526442, 0.026868437767893195, 0.23163534465768484, 0.9244445425588486, -0.3017042291556246, 0.1130782644673515, 0.20393663709060053, 0.9399374392563369, -0.25462627938586857, -0.22657095415053147, 0.018672730535929813, 0.6461660100505241, 0.2533286778896687, 0.28581679970119184, 0.6607592794827987, 
-0.18566249044806982, -0.13616568908938023, 0.9651929186393357, -0.12406036661204638, -0.05672169753435545, -0.22314637409867655, 0.6660833739903537, -0.7094513962556289, 0.22993288641240406, 0.028671704510642264, -0.6277063643325695, 0.7431645317713315, 0.23101312492913478, 0.02978998635719433, -0.6270662854175154, 0.7433258817746045, 0.20224795139371785, 0.11554338290759104, -0.2966345497140741, 0.9261497917395667, 0.015286317174635675, -0.22337718825196853, -0.2510672116900548, 0.9417187560466567]
+ }
+ props {
+ position: [1.0, 2.3977131044516938e-20, 1.761487500000001]
+ quaternion: [0.4962453252885034, 0.5012453042143057, 0.5012473257628618, 0.5012432947661197]
+ velocity: [1.380550766298209e-19, -7.914675859144183e-20, -4.9049999999999985]
+ angular_velocity: [0.010015951784962904, 0.01, 0.009984022983674677]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.1984015347585457, -0.0013226182264766773, 0.10613473266517873]
+ quaternion: [0.6958926671469007, 0.11712519384736712, 0.11996501950066925, 0.6983004216478704]
+ joints: [0.10147140459743603, 0.08003883476925437, -0.9408325337851626, 1.4470000336744073, -0.3696381127755711, 0.3392476767444894, -0.5961819093728019, -0.09595635769870423, -0.07981314217735108, -0.9285086322605994, 1.434511427337568, 0.3698570426473498, 0.34307051713217707, -0.59902559584686, -0.002450348618028625, 0.0014273625831709178, -0.025812055515946168, -0.0020388649174705163, 0.0014436059309695906, -0.048968342155591385, -0.0012964587291742632, 0.0011293286817226876, -0.0039139439554281286, -0.0014004660294925013, 0.0014417607228547089, 0.296883807156075, 0.001863582249972586, -0.0014485522822547008, 0.23685162596985065, 0.0006003326672063281, -0.0002991045041744379, 0.1799675807323982, 0.04138794663374596, 0.015041671787103627, -0.0681017691093426, 0.0375408600204265, 0.25347584002772905, 1.3305595347750176, 1.5624479046401825, 0.003267404896054917, -0.0019876057636429014, 0.7806712535382133, 0.00034379550129597027, 0.7809624493233305, -0.040997635800980786, -0.013804545242530317, 0.0661268202066367, -0.0368168408258234, 0.2563338540088746, 1.3307886370348616, -1.5624261759439313, -0.0033448107828945375, -0.001975763436302438, 0.7806737704541826, -0.0003479548263086892, 0.780967755820557]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.44241776727815196, 0.49429398105371797, 0.022220102129838215, 0.449573409434225, 0.48919580752766845, 0.01726380221084492, -0.36830024309341025, -0.6266675804667348, 0.17665720615296707, 0.3697188398227199, -0.6214338259480803, 0.17955191419743058]
+ velocity: [-0.9298460496134969, -0.0018780929935437863, -0.2773112620862802]
+ angular_velocity: [0.6713167605771755, -0.005703238983168889, 0.11242451423404758]
+ joints_velocity: [-1.080401145427269, 0.8128484570713813, 1.1055593917848137, -6.730034691291253, 0.1843710994625545, 0.2630294352283705, -0.23030141036090504, 0.9648319183106413, -1.0512383414145932, 1.4152740928569068, -7.216648964612118, -0.1704402677218062, 0.09469471868860897, -0.1284505732587455, -0.05350772695648592, 0.024080851913374, -2.102365118553055, -0.04761437002790877, 0.02315818939249817, -3.2140151408104565, -0.036438549323496466, 0.024811527313147563, -3.1282762335621856, -0.13064054221747917, -0.0016950908251478043, -6.2906731395977875, 0.04055902333607421, 0.021032812636547247, -3.7450566064786246, 0.01011755544778289, 0.005949147011655262, -1.5359857965744905, 2.0132730289591043, 0.7007827957194773, -6.315600006164353, 4.859017942734184, 1.3552626016177902, -0.9475064232031409, 0.027511862703975657, -0.059224216980194505, -0.11609985315789341, -0.028923563656020418, -0.00165670175143577, -0.014803824499789913, -2.0192688197375173, -0.6725225057842841, 6.380252340939365, -4.94265772581804, 1.4016405849066746, -0.9356433263302579, -0.027962864900119858, 0.060990909665410514, -0.11663377995495253, -0.02926764321377732, 0.001679862515133355, -0.015014913741718714]
+ appendages: [-0.44241776727815196, 0.49429398105371797, 0.022220102129838215, 0.449573409434225, 0.48919580752766845, 0.01726380221084492, -0.36830024309341025, -0.6266675804667348, 0.17665720615296707, 0.3697188398227199, -0.6214338259480803, 0.17955191419743058, 0.0015471952184100306, 0.5090771501617748, 0.04272936068719668]
+ body_positions: [-0.1984015347585457, -0.0013226182264766773, 0.10613473266517873, -0.09582209940921126, 0.10108333800210435, 0.12210805546576284, 0.20340639084362133, 0.23053657172046083, 0.3622842906054647, 0.4459563384900671, 0.3708721556416824, 0.06892059647954246, 0.5526445536367989, 0.384669361563529, 0.08507584939971856, -0.1984015347585457, -0.0013226182264766773, 0.10613473266517873, -0.09500045262140021, -0.10278783591863243, 0.12279886422133053, 0.20848754640884382, -0.2285206324729201, 0.35958305148954334, 0.4529133814939833, -0.36713576511178453, 0.06695928896255082, 0.5598205524059714, -0.3795424937891516, 0.08279154181807394, -0.1984015347585457, -0.0013226182264766773, 0.10613473266517873, -0.3086470216372388, -0.000848483695013676, 0.1331173543874916, -0.4180864698030424, -0.00018270843200365048, 0.1634418963790926, -0.5270598787290252, 0.00010641722211227154, 0.19705860634278405, -0.5983276773933862, -0.002496730399378584, 0.2538555136167873, -0.6647419635930002, -0.0009480244740135478, 0.3147353103456645, -0.5270598787290252, 0.00010641722211227154, 0.19705860634278405, -0.6062662284117203, 0.1809614607022876, 0.2245879937465478, -0.6562013777068821, 0.44700997037671786, 0.28261142016971413, -0.7100769890058923, 0.4440412613884833, 0.45594325385903145, -0.7370146759448137, 0.44255691343567793, 0.5426087887809118, -0.7474125428632263, 0.44210641415838753, 0.576229127942023, -0.7446232596857486, 0.4365414015508297, 0.5401750529241617, -0.5270598787290252, 0.00010641722211227154, 0.19705860634278405, -0.6062368860729553, -0.18020399719824867, 0.22802981414482063, -0.6557797998863397, -0.44496007037132235, 0.2919967435525516, -0.7094661951543436, -0.4381705321278501, 0.46527972722252453, -0.7363092744946842, -0.43477577796631633, 0.5519208372423698, -0.7466679802709878, -0.433585367328653, 0.5855352081053462, -0.7439276898225367, -0.4288248090629969, 0.5493623773532829]
+ body_quaternions: [0.6958926671469007, 0.11712519384736712, 0.11996501950066925, 0.6983004216478704, 0.5106655413270083, -0.13918279081174686, -0.2817447053367982, 0.8002929317804375, 0.47488105945574305, 0.2337501309671161, 0.31863773183263544, 0.7863325323904643, 0.6450328546868919, 0.0020916876986823295, -0.23718941714422354, 0.7264085776006735, 0.6172005224056183, -0.18744405545437226, -0.4400723178689192, 0.6247116104752074, 0.6958926671469007, 0.11712519384736712, 0.11996501950066925, 0.6983004216478704, 0.7978779096313682, -0.277932614701964, -0.1365200282293564, 0.5172103874628423, 0.7839811096252334, 0.3150068518994151, 0.23708903468882575, 0.4795133915101105, 0.7218563141817628, -0.23775252223539753, 0.003466745429597686, 0.6499116720930267, 0.6195695202842162, -0.44015498371387224, -0.18844710557908684, 0.6220007140232975, 0.6981070254978872, 0.10747975563166316, 0.11159350955588962, 0.6990290206456738, 0.7011452010790686, 0.08972351079007168, 0.09507587241824833, 0.7009320060456687, 0.7017189360036197, 0.08789323700003702, 0.0941590713686796, 0.7007134814049115, 0.6815049114039138, 0.19020701468409953, 0.19725534865572453, 0.6785740009195868, 0.6536923325027618, 0.27000700742929323, 0.2754395796334866, 0.6510880035275605, 0.6266142906108825, 0.32782244561077545, 0.33266946569417616, 0.6238733858330088, 0.686355428339249, 0.0844422869211794, 0.09748485511928945, 0.7157390789883827, -0.11782707308594852, 0.9799159495615114, 0.07491642228686209, 0.14236938685931377, -0.6975815514744946, 0.6982109039534378, 0.14682172490864218, 0.06576392428840741, 0.6344606832746866, -0.6909457381008024, -0.3250173552924982, 0.12007225798503757, 0.6335761705972313, -0.6921057889499939, -0.32400832044786015, 0.12078667655752331, 0.8492648966665428, -0.39897158086348916, -0.2536776826016526, 0.23498605555034574, 0.7802779424052187, -0.23631558749015036, -0.5714209166488641, 0.09380517960771438, 0.7165794524160042, 0.09067889306449553, 0.09099705059125111, 
0.6855733100927068, 0.1427252597266838, 0.07266737565490211, 0.978672350618838, -0.12864440479701558, 0.06741555350933182, 0.14527969711439526, 0.6904743950233669, -0.7054034750187268, 0.116452604874423, -0.3255925103983536, -0.6839141304694732, 0.6424093477951921, 0.11720623656097633, -0.32456421817351877, -0.6850917935315132, 0.6415372170251697, 0.23188677124641177, -0.2555538278490254, -0.38945492488982186, 0.8539587975257872, 0.08710659698781571, -0.5718314245572376, -0.2278466510368973, 0.7832669827468236]
+ }
+ props {
+ position: [1.0, 2.765257640924279e-20, 1.502748750000001]
+ quaternion: [0.4958693451475818, 0.5013693170977072, 0.5013717617065255, 0.5013668885947519]
+ velocity: [3.6935154009339523e-19, 1.521497048721324e-19, -5.395500000000001]
+ angular_velocity: [0.010017545560173424, 0.01, 0.009982423881844536]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.23031066365886266, -0.001788332752931669, 0.09982016326846974]
+ quaternion: [0.7010230022038314, 0.07604659323911699, 0.07610389582560051, 0.7049764982451666]
+ joints: [0.07548878344793818, 0.04127477550074378, -0.775412935729072, 1.2953709187904257, -0.35905730308949974, 0.3324296592717481, -0.6041874585613841, -0.06935026915961515, -0.041434005996316595, -0.7580760604913198, 1.274176395369968, 0.3588465870647315, 0.33247575558486314, -0.6047478011350585, -0.0024685940828511614, 0.0030316113600123576, -0.0985702365580853, -0.003109288027021297, 0.00324615686492213, -0.12714381204387462, -0.002851790045042325, 0.0028941240187140287, -0.0710823925515788, -0.008212188717267995, 0.00036402629661290415, -0.0075106598819506785, 0.003105651035268043, 0.0001953924755213403, 0.057061914494252956, 0.0006981095613437008, 0.0002821001371307385, 0.1011030729916492, 0.16485376231534446, 0.018276371574148693, -0.27382014732307225, 0.17352670952422536, 0.3451917437713674, 1.336750772027313, 1.563643496743715, -0.00015929524386232465, -0.006062502594770509, 0.7797607093507297, 0.00018026271190695002, 0.7804019958824797, -0.1656913179756747, -0.014455252754075021, 0.27145890343712753, -0.17385332243078508, 0.3517480817311503, 1.3367357119168193, -1.5636433846825082, 0.00015059998768830296, -0.006082931239454871, 0.779747473027511, -0.0001839811013501942, 0.7803980384956002]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.387123924527522, 0.6035853189040323, -0.10591082233193408, 0.3974669536952409, 0.5946568767360542, -0.11527341621161184, -0.36774344852862334, -0.6815916255432184, 0.12758273911121862, 0.36916093289721946, -0.6747467149290903, 0.13127631984261473]
+ velocity: [-0.47771781990876566, -0.003080822814190424, -0.1774298660890389]
+ angular_velocity: [-3.691005014774182, -0.010114706853064762, 0.015227696957121124]
+ joints_velocity: [-0.27810762981202625, -1.6367475531509286, 3.044656673110965, 0.5458534071876693, 0.16321884753188928, -0.16890143168195892, -0.09675435259121921, 0.3235297509914571, 1.5384279869114452, 3.099535681693006, 0.6557243388064936, -0.16408052874662715, -0.20537353041969303, -0.07947743200457213, 0.05461538473983366, 0.00904891564767681, -0.873957975929962, 0.0059936632071259585, 0.0123793913892941, -1.072531013324483, -0.029068089026231793, 0.01066726282198732, -0.8398572252512755, -0.11402925957271584, -0.04953551533225543, -5.071532724491943, -0.0016466745706592038, 0.0348842786135881, -2.815365158691475, -0.008928363353005217, 0.011072892752198285, -1.2361054634736204, 2.002622476151976, -0.2777566574449258, -1.6775901702634801, 1.207842622443348, 2.4925442208729325, 1.28155707892507, 0.017730357463050177, -0.06793127774508706, -0.025968725211685133, -0.0047804431102444825, -0.0039654126809191646, -0.005849163340627465, -2.05371521387254, 0.3159061685908192, 1.6794497569903926, -1.2057494012152266, 2.606031396655095, 1.2774085216302178, -0.01798187723731073, 0.06849500428884403, -0.02622188049044794, -0.004990308684547265, 0.0039285600621086, -0.00595870568501791]
+ appendages: [-0.387123924527522, 0.6035853189040323, -0.10591082233193408, 0.3974669536952409, 0.5946568767360542, -0.11527341621161184, -0.36774344852862334, -0.6815916255432184, 0.12758273911121862, 0.36916093289721946, -0.6747467149290903, 0.13127631984261473, 0.0027446055539231393, 0.5087849844663315, -0.11580727229315284]
+ body_positions: [-0.23031066365886266, -0.001788332752931669, 0.09982016326846974, -0.13061986392004904, 0.10069032191593486, 0.1285349199915373, 0.1819463629403798, 0.23209456897020156, 0.3499298853980398, 0.4548330928295572, 0.37120114497077905, 0.08390544401567707, 0.5596267895483419, 0.3817213864085118, 0.11113292989783702, -0.23031066365886266, -0.001788332752931669, 0.09982016326846974, -0.12948482966044728, -0.10318049083731015, 0.12842871000692496, 0.1886327845249834, -0.23153935524169214, 0.34361615968463854, 0.46483188681709225, -0.3656562469576652, 0.07844918304823567, 0.5696917087644794, -0.37441022353349335, 0.10604476194472143, -0.23031066365886266, -0.001788332752931669, 0.09982016326846974, -0.3436909913539129, -0.0016383246449933914, 0.10503867311458696, -0.4572551463150126, -0.0012516049116756164, 0.10482170685905869, -0.5712041615987082, -0.0011330768172676785, 0.1002414773513602, -0.6623038009037558, -0.003499564553736189, 0.10289093657510318, -0.7520985919612513, -0.0019464646861171385, 0.09553641885204019, -0.5712041615987082, -0.0011330768172676785, 0.1002414773513602, -0.676824589941805, 0.16791231213050017, 0.09740575077450817, -0.8380645735160517, 0.39229349676497915, 0.11463358662247572, -0.95380415632117, 0.35081267052619475, 0.24819634087172346, -1.0116736927008514, 0.33007234880647257, 0.3149774237015424, -1.0342634827328545, 0.3221463277800723, 0.3407752020624609, -1.0148243272882282, 0.3221319554865257, 0.30977899551396804, -0.5712041615987082, -0.0011330768172676785, 0.1002414773513602, -0.677620581711489, -0.1696991931651551, 0.10127574176178038, -0.84041522993468, -0.3923282813715825, 0.12528070464994773, -0.9554309122260316, -0.3462740642624073, 0.2579672571317868, -1.0129384999438864, -0.3232470571845893, 0.3243102410085417, -1.035389490151113, -0.3144365896956919, 0.3499417546724301, -1.0160484849440237, -0.3154330064261666, 0.31890019442429546]
+ body_quaternions: [0.7010230022038314, 0.07604659323911699, 0.07610389582560051, 0.7049764982451666, 0.5233297393045521, -0.13477081368551727, -0.2527377512659902, 0.8025499615743457, 0.49865917473104227, 0.20826957463485324, 0.28265836714222614, 0.7925068196693728, 0.6417224476545811, -0.03192690280647997, -0.276999976213816, 0.714453627768062, 0.6031635828937829, -0.22140623120710887, -0.4770202955596009, 0.5996870939728736, 0.7010230022038314, 0.07604659323911699, 0.07610389582560051, 0.7049764982451666, 0.8004989424484411, -0.24366125896358645, -0.13300684639201654, 0.5311683469786751, 0.788409777042377, 0.28032015525962745, 0.20905434293843228, 0.506089829692109, 0.7090210348848335, -0.2774909445441216, -0.03593186683652402, 0.6472996591475185, 0.594221090506746, -0.4760392292764562, -0.2270593019540026, 0.6072330863704746, 0.7046134012153283, 0.04022116009481825, 0.04247396707688374, 0.7071761981494751, 0.7066947876192382, -0.005898591317532183, -0.001267122921392444, 0.7074928113922931, 0.7070111923535739, -0.03206208037690657, -0.025354968711024727, 0.7060200581076614, 0.7097833384326258, -0.034751384817805375, -0.027997399326381994, 0.7030050493187376, 0.7093996984401365, -0.014632454821565933, -0.007776815587990177, 0.7046115811269701, 0.7089934205843211, 0.02111783219170195, 0.027954553725018778, 0.7043443118657964, 0.6466614051910763, -0.041001049539256325, -0.01672056681843102, 0.7614908821635704, -0.13749962722453618, 0.8915811387296001, 0.2753978454480855, 0.3321640442571946, -0.6604370798650255, 0.6145290375276133, 0.42198631619501675, 0.09002485500628442, 0.526642773592011, -0.6166278458727157, -0.5804282699350511, 0.07430014924541972, 0.5247772592462572, -0.6181751835417081, -0.5806998091138217, 0.0724982918856259, 0.7203497426607481, -0.3723277824947149, -0.5095659376425559, 0.28776870195631976, 0.606794628991284, -0.12775958967386739, -0.7658168357495068, 0.1703007326945723, 0.7627884094182092, -0.02519074006114035, -0.03343514949195637, 
0.6452916858683841, 0.33413269240250765, 0.2737181340297082, 0.8890673768722356, -0.15163418596302178, 0.09261402189079632, 0.42188732616430324, 0.6038018557159139, -0.6699679440131181, 0.06898821595938331, -0.5828324514856055, -0.6069552405865788, 0.5358659305073512, 0.06717474502993187, -0.5830851598379942, -0.6085383755541662, 0.5340227480967089, 0.28374788146854696, -0.5137975182455118, -0.35990496831217544, 0.7252362813532405, 0.16191548357417845, -0.7676608902050489, -0.11620923115131293, 0.60907762101702]
+ }
+ props {
+ position: [1.0, 2.970894652972362e-20, 1.2194850000000008]
+ quaternion: [0.4954932721013396, 0.5014932356848609, 0.501496143663144, 0.5014903486166324]
+ velocity: [3.6935154009339523e-19, 1.3371826793987848e-20, -5.886000000000004]
+ angular_velocity: [0.010019139080078537, 0.01, 0.00998082452559597]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.23405276784798174, -0.0017592652686443803, 0.09854878284672659]
+ quaternion: [0.7025457227637255, 0.06083895007647432, 0.06113512603133724, 0.7063926853706227]
+ joints: [0.07232405521627255, 0.026877651952064922, -0.8670051002798487, 1.5417406858852958, -0.36271709051338424, 0.3361856077077121, -0.6057481278231732, -0.06412928630021078, -0.027803294336940506, -0.8514611927441305, 1.5388545548513672, 0.3628351335405838, 0.3364145415154504, -0.6058406341012067, -0.001212885896015701, 0.0021057527613514327, -0.14040059776820432, -0.002120820069779113, 0.002112273123147233, -0.1526726223266465, -0.002105390344933425, 0.0017730906925126294, -0.06974903080065342, 0.0033255905830198926, -0.0033897169776724574, 0.13763432502141146, -0.009998175456095418, -0.0052607101074501755, 0.0605710593407709, -0.0027298854084947225, -0.0038727833826666704, -0.04588463572650255, 0.20980016884877403, 0.04022271401369089, -0.31684173377207975, 0.3008464635881334, 0.3599665814223977, 1.3902027332031888, 1.5648851229814709, -0.005374835737951142, -0.0052759209659549765, 0.7803268391806134, -0.00029096667729687915, 0.7802947826027866, -0.21397856799081372, -0.036775174141480876, 0.35113061864492795, -0.30715474511656327, 0.34039954100074893, 1.3905259502553438, -1.5648319262360968, 0.00524221202198515, -0.005006094502411115, 0.7803672338787006, 0.0002775925425636951, 0.7803194399058783]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.3602665733985246, 0.6204997534659197, -0.18559997490729144, 0.3709069992119262, 0.6135158060975043, -0.18109028652164505, -0.3361798549341376, -0.6228112592447421, 0.09662308411318316, 0.34005481917095964, -0.6194030829604678, 0.10472119144230847]
+ velocity: [-0.061294657150756666, 0.00048430363359450884, 0.014409362069782924]
+ angular_velocity: [-0.021127751775553816, 0.004076321043339142, -0.006716383258153461]
+ joints_velocity: [0.05930609341388562, -0.36832591705527895, -3.155759328853605, 5.712165309119618, -0.14358000238396854, 0.2587187842978693, 0.023751397750099712, -0.034156930568173746, 0.4266635521734691, -3.250319259913174, 6.165339221780599, 0.1587795024707462, 0.29014025582247205, 0.032416958739604, -0.009225049055202179, -0.02529773703172931, 1.5167172711651598, 0.00036950636848603023, -0.020445069766201245, -0.2204987411459373, -0.007060499658008046, -0.021168394577560243, -1.0821701155144499, 0.12410307239699746, -0.05808443926616059, 1.5200707297324985, 0.06364941879838991, -0.028430770172441895, 1.136525873418136, 0.003218651283471159, -0.03770652861962861, 1.2006546590117146, -0.6225887197437707, 0.28906363108481425, 0.0985616571416499, 0.35798442133923025, 0.5238583914240015, 0.1342568224091573, 0.02694521298960307, -0.10890288823267481, 0.0001506063213490155, 0.008727521660703822, -0.010019939672185275, -0.002848408923995372, 0.6431149579494972, -0.3654341316299582, -0.21855184601449612, -0.9574671216566029, 0.7652429267731806, -0.7334320427844019, -0.026053835845285537, 0.11061926107712633, -0.01088774698849533, 0.009823925262727878, 0.011365299527868687, -0.002750620130833337]
+ appendages: [-0.3602665733985246, 0.6204997534659197, -0.18559997490729144, 0.3709069992119262, 0.6135158060975043, -0.18109028652164505, -0.3361798549341376, -0.6228112592447421, 0.09662308411318316, 0.34005481917095964, -0.6194030829604678, 0.10472119144230847, 0.0007086718071675362, 0.50675469240825, -0.12074727797901184]
+ body_positions: [-0.23405276784798174, -0.0017592652686443803, 0.09854878284672659, -0.1356713501041957, 0.10071434407327845, 0.13148695978674044, 0.14895665308127157, 0.21701035253412773, 0.3950065819095304, 0.3922670755504607, 0.3417182002411995, 0.09528203955914838, 0.49936452447810603, 0.3532982792278413, 0.11044015762220499, -0.23405276784798174, -0.0017592652686443803, 0.09854878284672659, -0.1345589607242152, -0.10315662088679864, 0.13147637138026963, 0.15576617973610463, -0.21647812842193806, 0.3900300442833447, 0.3979225247003999, -0.3344952363800789, 0.08668358052986463, 0.5053811689481363, -0.3442170735910224, 0.10053478975658489, -0.23405276784798174, -0.0017592652686443803, 0.09854878284672659, -0.3474681945515887, -0.0017275559884511746, 0.0941562851695784, -0.46031986359088495, -0.0015613081842191746, 0.0814485286701189, -0.5730961239972646, -0.0017336487939531523, 0.06451170420852195, -0.6639976836230818, -0.005395533215336135, 0.07044982175310915, -0.75401525606653, -0.003886020404022323, 0.06668636939417963, -0.5730961239972646, -0.0017336487939531523, 0.06451170420852195, -0.6849067651804135, 0.16219777106630212, 0.04540856692916838, -0.8715764468826352, 0.36567849778285055, 0.025607798330535985, -1.002779476050829, 0.30008466166071934, 0.13255887637967548, -1.0683807015395859, 0.26728788813040494, 0.1860341797461321, -1.0939585302309096, 0.25450687471620004, 0.20655407660023947, -1.069118514605671, 0.25926157247115333, 0.180115236528161, -0.5730961239972646, -0.0017336487939531523, 0.06451170420852195, -0.6856082360085167, -0.1654504755931131, 0.04784182038525146, -0.8752419161628417, -0.3655257895323376, 0.02232745673949518, -1.0070517233681902, -0.29911029367224884, 0.12801792901655246, -1.0729563365385375, -0.2659026920834169, 0.1808629322746116, -1.0986404903208253, -0.252956705861405, 0.20114522693244724, -1.0734647718991233, -0.2577887792923313, 0.17504005388821506]
+ body_quaternions: [0.7025457227637255, 0.06083895007647432, 0.06113512603133724, 0.7063926853706227, 0.5144036694740793, -0.1708659499471406, -0.3029751331725294, 0.7838365650180856, 0.48803660362059026, 0.2358528811488541, 0.32882037802591735, 0.7733504063328319, 0.654822745742129, -0.009279207106049918, -0.2237239698307555, 0.7218508525293779, 0.6222496504918367, -0.2041673443246124, -0.4288433846729863, 0.6222655537601606, 0.7025457227637255, 0.06083895007647432, 0.06113512603133724, 0.7063926853706227, 0.7815910767725152, -0.2944635358591785, -0.17040424731974377, 0.5228470208907019, 0.7662894331294906, 0.3322557597923528, 0.2413542541807311, 0.4941201663034547, 0.7166872619428273, -0.21720022414170498, -0.009625657341793449, 0.6626392517217983, 0.6192633194429746, -0.42110510563371467, -0.20685853399366358, 0.6295974730855253, 0.7053905634011617, 0.010606693076871591, 0.012236802310658202, 0.7086338347805408, 0.7048227755430814, -0.044033484366390374, -0.04102974250028782, 0.7068256274079641, 0.7036178967810292, -0.06919165734696732, -0.06504389604876804, 0.7042042753845067, 0.7053555530671948, -0.019648352770908775, -0.017451216376464053, 0.7083663889807614, 0.7090504230212762, 0.003765566810059646, 0.0019395308099864096, 0.7051450604947789, 0.7099454322084735, -0.011162832864382414, -0.015584174303208145, 0.7039957442769473, 0.6270429983350347, -0.0911719167882481, -0.044840677712975896, 0.7723302877981723, -0.1293873714274127, 0.8307917877646112, 0.34023063090752753, 0.42105466544173586, -0.6314596927893418, 0.5551709132189871, 0.5309633653799182, 0.10546003092365305, 0.4705189817667833, -0.5634229251933264, -0.6776985515628079, 0.04348756573340406, 0.4691510383394198, -0.5628391924247137, -0.6793192743707944, 0.04043105445171803, 0.6479629013871157, -0.3421033933334346, -0.6128888073534652, 0.2957645288343204, 0.5140078488049894, -0.06965357890122431, -0.8337848730290353, 0.18906902395970027, 0.7733374164168655, -0.050347479371691164, 
-0.08629422846042653, 0.6260732208176809, 0.42447014853539455, 0.3508803462772641, 0.8253342828504623, -0.1246250261801392, 0.10117484705566612, 0.5413460766686827, 0.5539288174087514, -0.6243964612670619, 0.04731334158841888, -0.6858835018732365, -0.5611015346087844, 0.46096674221196343, 0.04438432095380477, -0.6874651002226818, -0.5604542105581906, 0.45968777001012684, 0.3025313754371257, -0.6189127554404039, -0.34348401338975154, 0.6383106614924717, 0.19832439337523367, -0.8371283375346518, -0.06747987480267527, 0.5053019374447626]
+ }
+ props {
+ position: [1.0, 2.6676794454005827e-20, 0.9116962500000007]
+ quaternion: [0.4951171062266801, 0.501617059926224, 0.5016204716192574, 0.5016136748187426]
+ velocity: [4.61870125478825e-19, -1.2540605128415669e-19, -6.376500000000006]
+ angular_velocity: [0.010020732344637632, 0.01, 0.009979224914969742]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.24433438060195692, -0.0017962076433118187, 0.09340703887360832]
+ quaternion: [0.7047259537476365, 0.026427032071557645, 0.026604497412898796, 0.7084879270727061]
+ joints: [0.06409165072029244, -0.035059070540012234, -0.9058577188812151, 1.7385309895529086, -0.3681612866370557, 0.34716370049277645, -0.6038229347507793, -0.05626202502615892, 0.036744820817569405, -0.8956325815948636, 1.7546736738730426, 0.3689918402439796, 0.34846994851185265, -0.6035879022396061, -0.0033135447008843352, -0.00022288191790906093, -0.05842011503985777, -0.002709054066976314, -0.00023700354639417465, -0.09426782856477853, -0.0015695802282039604, -0.00028647086741966093, -0.03765392681606732, 0.007171178403060918, -0.003304022000926985, 0.13760415740225165, -0.0019438989488972166, -0.0028709028867526967, 0.10451889990663134, -0.0003680074744907688, -0.0021248753105114564, 0.05729365579379155, 0.1287087525586643, 0.035508115292432404, -0.20738282208191444, 0.13973088975915443, 0.32550717625775527, 1.4010567053520047, 1.5647034065604286, -0.0057187604227979605, -0.0034370066938186426, 0.7807188489697836, -0.00031579014848522003, 0.78045887562362, -0.1310651814764284, -0.03555729171087928, 0.21617277263862042, -0.14808608492483905, 0.3284650121704502, 1.3955077180131692, -1.5647443807771622, 0.005763831288042127, -0.0038391942690191947, 0.7806448353478906, 0.000320439017499987, 0.7804253565823642]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.406343654617665, 0.5797982117020389, -0.05668884818236977, 0.4178359599815667, 0.5701078919411161, -0.05414573806424066, -0.3059272335286657, -0.5680426039164983, 0.05834454907036289, 0.3112650199071059, -0.5709218986534508, 0.0680035970186459]
+ velocity: [-0.2736091482630564, -0.0018926771508668344, -0.14443301839386802]
+ angular_velocity: [-2.865473353057027, 0.0022616786420445126, 0.003238510017778679]
+ joints_velocity: [-0.44469986332239236, -1.360158011109948, 0.4804215866619491, 2.4291766216146673, -0.08689819147389787, 0.19014085996236138, 0.0410689958155338, 0.4060851603179207, 1.400645752910102, 0.3696924432620612, 2.718297868949766, 0.09616004746405087, 0.2042590547932576, 0.04558762294324819, -0.035872274738900735, -0.027623295475097568, 1.257032720078592, -0.02739400663922532, -0.038686618191001866, 1.9265791119343816, -0.02161158235959277, -0.03162687068435193, 1.950889180283002, 0.04120850628728974, 0.008919750172399897, -0.5176971445642333, 0.14669479746231393, 0.04794537495044989, 0.48746156489385084, 0.05406902383355825, 0.03983893224906673, 1.5634841274543452, -1.9759088648527319, -0.7103091008589627, 4.5160927104459265, -4.094916871330123, -1.943703695667663, -0.6566523012026193, -0.02167585687639392, 0.06151927917676751, 0.04687938211796277, 0.00781891742854798, 0.004434004670238334, 0.006972562094295739, 1.9609341004683964, 0.6842364846485022, -5.039468367523984, 4.422851163484357, -1.742664475544493, -0.3072156398091558, 0.020992848690684646, -0.06012817345733686, 0.04547237791946517, 0.00704595035054496, -0.0045267086625128735, 0.006632383514519693]
+ appendages: [-0.406343654617665, 0.5797982117020389, -0.05668884818236977, 0.4178359599815667, 0.5701078919411161, -0.05414573806424066, -0.3059272335286657, -0.5680426039164983, 0.05834454907036289, 0.3112650199071059, -0.5709218986534508, 0.0680035970186459, 0.0007787419167942401, 0.5187114840251598, -0.041886836631411364]
+ body_positions: [-0.24433438060195692, -0.0017962076433118187, 0.09340703887360832, -0.1496236787747478, 0.10064883958482707, 0.13582829211761194, 0.10642204141124562, 0.19211652874432733, 0.4359194516662851, 0.3284125253343738, 0.3125249370198267, 0.11841535639640116, 0.4364863240556138, 0.323346051203048, 0.12447240659882033, -0.24433438060195692, -0.0017962076433118187, 0.09340703887360832, -0.1485378694632824, -0.10322226867915088, 0.1358387494846838, 0.11211245953849153, -0.1909510336478475, 0.4330684970089612, 0.32810457349819333, -0.3046781922607005, 0.10903092148711246, 0.4364690825264959, -0.3136861333263399, 0.11214869636239519, -0.24433438060195692, -0.0017962076433118187, 0.09340703887360832, -0.35766633196785796, -0.0014580506987222617, 0.08723383021349383, -0.47095624669520914, -0.0008644516054764612, 0.079356217648584, -0.5846827110920729, -0.0006459176920526978, 0.07089370182326198, -0.6748835230678998, -0.004459084431037592, 0.08358375533840301, -0.7647230529580407, -0.0037891722328370738, 0.09051308704947664, -0.5846827110920729, -0.0006459176920526978, 0.07089370182326198, -0.6836291471355072, 0.1722637301702688, 0.06368841046686588, -0.8191141214743478, 0.41298384623298956, 0.08211920264067113, -0.9358225072806577, 0.37238223755523403, 0.21510785385296036, -0.9941764430262554, 0.352081522678741, 0.28160188642929007, -1.0169585399247325, 0.34413752782580964, 0.3072244311515655, -0.9978950303886527, 0.3446790667993705, 0.27600047991108484, -0.5846827110920729, -0.0006459176920526978, 0.07089370182326198, -0.6848086991985238, -0.17286517931199064, 0.06345245287632761, -0.8245781578373335, -0.4112366505405023, 0.08035176603305003, -0.9419257568707188, -0.36982716741225363, 0.21252632865442256, -1.0005992978213991, -0.34912251709060055, 0.27861331872907275, -1.023512920551966, -0.34102647987360335, 0.3040704565513467, -1.0041663044937255, -0.3416379755537502, 0.2730224242225791]
+ body_quaternions: [0.7047259537476365, 0.026427032071557645, 0.026604497412898796, 0.7084879270727061, 0.5099101190482822, -0.19699447912192708, -0.3564451512451368, 0.7577147879258255, 0.47956499614996473, 0.2623596172813894, 0.34873896320993675, 0.7612923099748912, 0.6663749335304293, 0.017762095935007303, -0.19348388455781235, 0.7198561955882097, 0.6415161905618567, -0.1811850472215961, -0.3987787704394904, 0.6297653913617989, 0.7047259537476365, 0.026427032071557645, 0.026604497412898796, 0.7084879270727061, 0.754627912224856, -0.35097643906748544, -0.19712600888380977, 0.5181636709884272, 0.7522739612106265, 0.3559938678879317, 0.2724849944309251, 0.4828086382001068, 0.7146888086521895, -0.18159463313941826, 0.022578173363057924, 0.6750803819471144, 0.628412288625051, -0.38581692485264485, -0.17909879492819378, 0.6512809821095216, 0.7063734676792066, 0.005833955471602634, 0.0059004281355917385, 0.7077906993386005, 0.7068251199728337, -0.027423838416216414, -0.027485414370711614, 0.7063219767632753, 0.7067359393062626, -0.0406132717670596, -0.040890095784080475, 0.7051261407124533, 0.7051906660547677, 0.008910880899803804, 0.006845984337763539, 0.7089286658114976, 0.7044065760890351, 0.04677987273389401, 0.042828239932621213, 0.7069573968308113, 0.7029314814559476, 0.06768469891062408, 0.06231519751070362, 0.7052821634429209, 0.6605003906382099, -0.056451201887711036, -0.026471840008533712, 0.7482324087198894, -0.12249234543057493, 0.9031536721912973, 0.22413470176135583, 0.345069131022524, -0.6758380659890866, 0.6115012991334028, 0.39379041310105595, 0.11932384618336517, 0.5491948671119979, -0.6214142777191163, -0.5569676640775515, 0.04490339137147481, 0.5482550457696553, -0.6207622515166166, -0.5588158870100924, 0.042372589078998435, 0.7432240720739745, -0.36544130419213106, -0.5006538290822846, 0.25182608146373686, 0.620183406313903, -0.13367246793314802, -0.7621652434673807, 0.1288734088159943, 0.7505032115451999, -0.026156549648129815, 
-0.05681312031162349, 0.6579004740318273, 0.34699319606640333, 0.23279111581455422, 0.9000628574355642, -0.12365626125014764, 0.1163286112982091, 0.40132734259204517, 0.6102577858549807, -0.6730449116402368, 0.04745599840156722, -0.5634995100116182, -0.6194310609090441, 0.5445377775815679, 0.04480129337343892, -0.5653693575847102, -0.6188489452295877, 0.5434853416648385, 0.25654709422208444, -0.5057996192464815, -0.36551421710871723, 0.7380716026970019, 0.13464788907178665, -0.7655422945555836, -0.13104887034667262, 0.6153382279656813]
+ }
+ props {
+ position: [1.0, 1.8324823718625826e-20, 0.5793825000000004]
+ quaternion: [0.49474084760052556, 0.5017407897722723, 0.5017447455614047, 0.5017368671881023]
+ velocity: [4.61870125478825e-19, -1.716653439768715e-19, -6.867000000000009]
+ angular_velocity: [0.010022325353810098, 0.01, 0.009977625050006616]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.2541346359316792, -0.0018524177615704155, 0.08782510184330196]
+ quaternion: [0.7050697928315172, -0.013546101905684674, -0.013507440767518605, 0.7088798483548495]
+ joints: [0.03813766647066743, -0.0564510205218579, -0.9107413020532982, 1.7909881479665546, -0.3726955299912399, 0.3535741252587462, -0.6021504130906691, -0.03316862230085105, 0.05979787450506428, -0.9059414932864317, 1.81325893518048, 0.373624292715183, 0.3553478375791571, -0.6017669904055991, -0.004139387649482148, 0.00014775634207823488, -0.0022760473779450405, -0.0036688029971609346, -0.00019620064954660674, -0.01295022729932099, -0.002159627434307131, -0.00026227455701723144, 0.04092009987883586, 0.005470001576975803, -0.003243872881624933, 0.0893805277158246, 0.001535017399805044, -0.0013268279346219315, 0.10734835419577192, 0.0009327450969040318, -0.0010034103255379693, 0.10452946917213293, 0.012664309912073419, 0.0012743721318287824, 0.03621995770922891, -0.02473782816511619, 0.22673708058581582, 1.3491939726104591, 1.563538645391736, -0.0011784073029246435, -0.00253908359772317, 0.7808490376726105, -2.082089280181707e-05, 0.7807625879270118, -0.014361666371187508, -0.0011922669333981244, -0.037730917882390566, 0.028115627494012058, 0.22954191694630338, 1.3514349188056367, -1.5635350408607585, 0.0011481048573092594, -0.0026004905715598507, 0.780834933848696, 1.853176838777804e-05, 0.7807580860288588]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4471835099133121, 0.44734122412522503, 0.07166990963803063, 0.45658868768860045, 0.437770125567797, 0.0709697383122008, -0.2883660548410366, -0.55540607688171, 0.04851498948855563, 0.29289614522833607, -0.5611543087222017, 0.05657728769882066]
+ velocity: [-0.08950985219845241, -0.0005482472090103179, -0.06162316826474175]
+ angular_velocity: [-1.2464993689838004, 0.000682287127424781, 0.0014991796852603808]
+ joints_velocity: [-0.4080133416615965, 0.484610314510626, -1.2117513752115503, 0.0803623641857967, -0.09204122351936689, 0.044631652837113234, 0.02068990784611251, 0.3343943689515231, -0.4584000660982136, -1.3151610485537735, 0.06030586071511937, 0.08872270739514644, 0.049921550541194154, 0.02259520906376473, -0.005410029781633, 0.02368250376974116, 0.7945435135623722, -0.0029863255683662375, 0.02381073828450555, 1.1510544163155079, 0.007676118769343171, 0.01896906079819194, 1.07014501779638, -0.08643269376036002, 0.005493195703026822, -1.294717934916928, 0.010376063163722098, 0.02552291159974483, -0.3069801552735282, 0.0032697745720032803, 0.014937985634589908, 0.503112211256042, -2.368922551080678, -0.38798579216827017, 4.1596672834156525, -2.3956500912268233, -1.7551260167859783, -0.37488117413542127, -0.022858498972809844, 0.09572241835453783, 0.010254429989581206, 0.0022841277934428006, 0.005687681376424183, 0.006063272147506719, 2.4058900409054016, 0.40302639485085967, -4.1694525056298835, 2.4230211726270294, -1.8328810476836679, -0.38760442613324686, 0.023886233723389792, -0.09739618947334919, 0.017116087919471303, 0.003709278177174932, -0.00576276296813613, 0.006727030085484906]
+ appendages: [-0.4471835099133121, 0.44734122412522503, 0.07166990963803063, 0.45658868768860045, 0.437770125567797, 0.0709697383122008, -0.2883660548410366, -0.55540607688171, 0.04851498948855563, 0.29289614522833607, -0.5611543087222017, 0.05657728769882066, 0.001753015748532538, 0.5179579821668641, 0.034969456039203164]
+ body_positions: [-0.2541346359316792, -0.0018524177615704155, 0.08782510184330196, -0.1648362497171066, 0.1005753918943236, 0.14073076786820882, 0.05908468684200807, 0.178878578830905, 0.4689204662304123, 0.30285900727039855, 0.2940603886191349, 0.1657801659689725, 0.4106902020887069, 0.30211406955054815, 0.17766721719587653, -0.2541346359316792, -0.0018524177615704155, 0.08782510184330196, -0.1637381475091268, -0.1032956482593207, 0.14076292703495175, 0.06322476824702059, -0.1775358397516114, 0.46780525130455963, 0.3005542522425185, -0.2872239554145134, 0.15759557454098078, 0.4087844304617862, -0.293759786259816, 0.16638430466240317, -0.2541346359316792, -0.0018524177615704155, 0.08782510184330196, -0.36692887883689507, -0.0014287514968385501, 0.07519049341966119, -0.48037498466919343, -0.000641388107194977, 0.07005482901119066, -0.5943688220354051, -0.00016567636482871419, 0.07330280934366368, -0.6837588211615904, -0.00361837259473261, 0.09088891536118529, -0.7730531196973999, -0.0028920191385958098, 0.10295565522128816, -0.5943688220354051, -0.00016567636482871419, 0.07330280934366368, -0.672509180295641, 0.1829793845745375, 0.08289607873307639, -0.6967520438896179, 0.45236738072136384, 0.14192520498242275, -0.785620858180675, 0.4506900386079174, 0.3002124112434976, -0.8300550695109067, 0.44985137124708285, 0.37935566560088535, -0.8473534723431013, 0.44958455400555525, 0.410004317437055, -0.8376708938292229, 0.44503882092115593, 0.37501532237563634, -0.5943688220354051, -0.00016567636482871419, 0.07330280934366368, -0.674484523596411, -0.18246565172281745, 0.08269712150393142, -0.7014749013909255, -0.45144330148316614, 0.14240130967284167, -0.7897962417197112, -0.4480429134289936, 0.3009670610178527, -0.8339567172751222, -0.44634272689438936, 0.3802495873034568, -0.8511499341673802, -0.4457435874860306, 0.41095268350278386, -0.8415284848359043, -0.4414567690611341, 0.37591418051087777]
+ body_quaternions: [0.7050697928315172, -0.013546101905684674, -0.013507440767518605, 0.7088798483548495, 0.5017679148292739, -0.23727940865619732, -0.38928795659044585, 0.7351070185553144, 0.4988726705058406, 0.24330765862734985, 0.3304074391999103, 0.7633861185514218, 0.6699930952158967, -0.003078039077048304, -0.21256472518799535, 0.7112777345342778, 0.6389428103201402, -0.20162417291024107, -0.4139306140727216, 0.6162476975793436, 0.7050697928315172, -0.013546101905684674, -0.013507440767518605, 0.7088798483548495, 0.7315420407280105, -0.38667791609759905, -0.2377254116555334, 0.5087367300511366, 0.7554069595041697, 0.33768904880488787, 0.25406014350898803, 0.5007792680713812, 0.7069818862641152, -0.20084288295474723, 0.003338482759780187, 0.6781058939381975, 0.6156981175848415, -0.4013440907022592, -0.19777775937028483, 0.6486314105555859, 0.7065205430184639, -0.014374425683760454, -0.014288300113329905, 0.7074022495433758, 0.7077083103590557, -0.018861776960737854, -0.01895586737705004, 0.7059984815214532, 0.7087030428392512, -0.0042506161882709115, -0.004636974981188724, 0.7054788641734281, 0.70619601908626, 0.028462103979148026, 0.025838386278668836, 0.7069720426284999, 0.7031007400991736, 0.06676768411847293, 0.06326785350378342, 0.7051160219016968, 0.6983357464820067, 0.10377452097812628, 0.09964954516516561, 0.70116189437929, 0.7042244575559917, -0.004732263638510318, -0.004161245501976852, 0.7099494370000566, -0.11502081554607181, 0.9605416533782589, 0.015314341206335173, 0.2527754241810169, -0.6897641424923484, 0.6783033861151794, 0.16983900779205705, 0.18784210270513552, 0.6210562507568164, -0.7038245886275387, -0.3448318347135043, -0.0033298058923493246, 0.6201604041335231, -0.7044091772973673, -0.34524150074052845, -0.004134035633539028, 0.841576597270325, -0.41537940108136057, -0.3208345776173143, 0.12756942380253666, 0.7257550088616176, -0.2622135505964812, -0.636021077673685, 0.0009538085052120948, 0.7137477443045722, -0.0037996938420436998, 
-0.005092867930733863, 0.7003740304459339, 0.25056545537535685, 0.02009962502795535, 0.9607348843530078, -0.11747952857195451, 0.18293434743155254, 0.1724008900311027, 0.6761430484488012, -0.6925630192846264, 0.0008406388972089784, -0.3480393314705321, -0.7004524900106291, 0.623084445571946, 2.989610907853116e-05, -0.34844170397230034, -0.701062152727912, 0.6221737989112113, 0.13263567402479395, -0.3222101797536604, -0.41152392723095593, 0.842161763175884, 0.004435421688979312, -0.6363992148292968, -0.25690931952488133, 0.7273059658358082]
+ }
+ props {
+ position: [1.0, 1.1591928227490846e-20, 0.22254374999999973]
+ quaternion: [0.49436449629981666, 0.5018644251735016, 0.5018689654761259, 0.5018599257117691]
+ velocity: [6.006480035569696e-19, -1.716653439768715e-19, -7.3575000000000115]
+ angular_velocity: [0.01002391810755534, 0.01, 0.00997602493074737]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.25501047276670497, -0.001861994562340482, 0.0873736764329003]
+ quaternion: [0.704947045853522, -0.017326790429849574, -0.017353867329018394, 0.7088358682830036]
+ joints: [0.027223464329484047, -0.009646832113460915, -1.0108634741962454, 1.7229212363782513, -0.37532769750045364, 0.3489137383126649, -0.6024966043979272, -0.02573317486058536, 0.012645268825987523, -1.0096384162153798, 1.736216058877133, 0.37590837082443024, 0.35057882300542786, -0.6020926907177797, -0.0033348789685529196, 0.0010565342501578793, 0.01615501612766379, -0.0033194362358308356, 0.0008499589492389699, 0.020168252196274564, -0.0020184231882263855, 0.0005798318333381302, 0.0745728563864633, 0.00032008506043644676, -0.002526444423228021, 0.024870402188710283, -1.370661548944017e-05, -5.287561992319382e-05, 0.08102728354958716, 0.00025149785282327483, -0.0003070745595870483, 0.11333553780594197, -0.014012742056799981, 0.0023423706219093607, 0.20132782916243233, -0.099248246353559, 0.132238318006801, 1.3681183505657977, 1.5625710704236064, 0.002322850368296049, -0.0009447894235096916, 0.7812978783692536, 0.0001383944417840305, 0.7811344305271054, 0.01426152309830514, -0.001599109881034606, -0.20080354048505203, 0.09909631269440417, 0.13395232966708132, 1.3682303167102385, -1.5625632716919036, -0.002341867646798085, -0.0009135004337726153, 0.7813037254921582, -0.00013956787249009425, 0.7811375445642491]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4489235731924273, 0.3732851017386872, 0.09205256886891969, 0.4562839956889576, 0.3653237793566859, 0.08898182046243433, -0.2909748361861679, -0.5748716026921781, 0.12277451657216823, 0.29312871054738054, -0.5777440020037057, 0.12747169772330466]
+ velocity: [-0.014083535291031117, -0.0001646000859212195, 0.0033414839495510473]
+ angular_velocity: [0.014799436193169986, -0.0018898060803229492, 0.0021060590573945847]
+ joints_velocity: [-0.10400185894957331, 0.822544598796172, -1.7550025514542795, -2.712975768890385, 0.0029036581446809494, -0.18796449362672255, -0.029261709757098013, 0.05184984638644144, -0.8628118337196455, -1.7891733623229764, -3.00492424056251, -0.011836443727811258, -0.19664769008881938, -0.03019935213522205, 0.025491169668107362, 0.010217403510241365, 0.4812422395859041, 0.015353625048488015, 0.014938459240804124, 0.41285866143342276, 0.0021637972605188568, 0.012944459423391471, 0.26615830944223184, -0.10353160253805578, 0.019722704941123968, -0.9777529996188893, -0.05217063760388057, 0.02230114139575289, -0.5120134808982083, -0.022537815749312535, 0.01197313302764748, 0.020298953338430285, 0.475063825998685, 0.18803936409420519, 1.3783272400163527, -0.5607156878418378, -1.1642833029540127, 0.5138041238532257, -0.01658164049703155, 0.05283705710261086, 0.04424929195700539, 0.012696337431085027, 0.0016390492047829285, 0.007924057697219582, -0.4634798822306857, -0.17836994791617067, -1.3443113024145843, 0.5019488719785425, -1.150631005058816, 0.4911204921197893, 0.016150170054473188, -0.05158612916890161, 0.043350067381204976, 0.01248772829024987, -0.00157804235846651, 0.007782273030294061]
+ appendages: [-0.4489235731924273, 0.3732851017386872, 0.09205256886891969, 0.4562839956889576, 0.3653237793566859, 0.08898182046243433, -0.2909748361861679, -0.5748716026921781, 0.12277451657216823, 0.29312871054738054, -0.5777440020037057, 0.12747169772330466, 0.002416020976409889, 0.516559479559114, 0.053899722104603476]
+ body_positions: [-0.25501047276670497, -0.001861994562340482, 0.0873736764329003, -0.16629980529354116, 0.10056956864197507, 0.14125187058374467, 0.01746190379089463, 0.17757201242559303, 0.49378940184918063, 0.3141679877627255, 0.2944160768779505, 0.2429911271633971, 0.4179953197373091, 0.29998387084487665, 0.2749694814791135, -0.25501047276670497, -0.001861994562340482, 0.0873736764329003, -0.16517872464432648, -0.10330134803430631, 0.14127156197062005, 0.020304060474177155, -0.17710483760068135, 0.49359191956796433, 0.31474131429508767, -0.2896936123705704, 0.23821517511532753, 0.41913885870032497, -0.2940404723602184, 0.268480466618086, -0.25501047276670497, -0.001861994562340482, 0.0873736764329003, -0.36789828172477274, -0.0015466652607407753, 0.0756009447294341, -0.48146029076327945, -0.00089540951918441, 0.07509115297388265, -0.5948976999018801, -0.0005412290888340243, 0.08680436771283243, -0.6841221442271945, -0.0035716029977832945, 0.10528623852467749, -0.773596882330247, -0.0022948762254910806, 0.11588089573177812, -0.5948976999018801, -0.0005412290888340243, 0.08680436771283243, -0.6674631880616274, 0.18457120234186378, 0.10121607861240227, -0.626761241563053, 0.452391504356611, 0.15829220347172304, -0.6888443850634017, 0.46331921968525536, 0.3285319979582391, -0.7198858200183542, 0.4687830532712342, 0.413651520091779, -0.7318574332781248, 0.47095858175392796, 0.44667565283042143, -0.7288845593604476, 0.4656724139605469, 0.41059429695110594, -0.5948976999018801, -0.0005412290888340243, 0.08680436771283243, -0.6686123052864782, -0.18512018617925236, 0.1021934971718533, -0.6298857927585101, -0.4528466696367042, 0.16105633716388848, -0.692134685732272, -0.46217984218188063, 0.3313304881962255, -0.7232589950587154, -0.4668464078895725, 0.416467188526974, -0.7352622619122801, -0.46871230633868777, 0.449498776426736, -0.7322332366379755, -0.4636906282225815, 0.41338432577127343]
+ body_quaternions: [0.704947045853522, -0.017326790429849574, -0.017353867329018394, 0.7088358682830036, 0.4819757108740507, -0.2853112311439719, -0.4154776056414286, 0.7167114305779071, 0.5304178240439662, 0.17988890022983534, 0.2732238783507349, 0.782077763275205, 0.653672530301817, -0.06389089372272788, -0.2826874492084857, 0.6990836737387957, 0.6052783664345818, -0.25496651211120447, -0.4773840237201027, 0.5837248244891899, 0.704947045853522, -0.017326790429849574, -0.017353867329018394, 0.7088358682830036, 0.7134973285623902, -0.4140992577997821, -0.28644280406474043, 0.4872308352552527, 0.7771152558277898, 0.2768546772870313, 0.18668972714222387, 0.5334700672048661, 0.695826658414373, -0.27621055505624104, -0.06059096635350464, 0.6601982471286629, 0.5826312132232704, -0.4701147505264745, -0.25362775120807524, 0.6125405737847457, 0.7062567956135617, -0.011967318094909873, -0.011293165079356633, 0.7077644992296305, 0.7075228034947586, -0.005114475672855655, -0.0038866239746195184, 0.7066613183326147, 0.7079425860476833, 0.021089977470155667, 0.022632191618563547, 0.7055921567116287, 0.707529562473618, 0.03078547089588323, 0.0305079963014121, 0.7053534115387, 0.7057069867264664, 0.05943561991933011, 0.05903287234969089, 0.7035340616860006, 0.7011220809038896, 0.09942042609787544, 0.09867336552524714, 0.699147318867374, 0.7128416421332835, 0.020110327697127995, 0.023614284785552723, 0.7006388038887831, -0.08815748015012756, 0.9741548577341071, -0.09232525840015009, 0.18635079422413825, -0.6839347287744285, 0.6992729902760634, 0.046209930063267894, 0.20276886893744184, 0.647753203291911, -0.7279813008080885, -0.2243164249119344, -0.011880859125505516, 0.6474224755177732, -0.7285472567370695, -0.22346549599941426, -0.011234099566255583, 0.876069505239363, -0.42712900755430927, -0.21090790150540376, 0.07470535441262363, 0.7728473667542647, -0.3172664971724538, -0.54693729964987, -0.05393058246142793, 0.702910970996137, 0.02181899943092187, 0.02191922102355977, 
0.7106051265413315, 0.18597479909835998, -0.08937709736043856, 0.9741397414336744, -0.09206993389419574, 0.2006116924840915, 0.04826841898095456, 0.696750363643078, -0.6869963896054958, -0.010459592263150908, -0.22710967117462616, -0.7249804944464242, 0.6501654226893802, -0.009801628150535134, -0.22625615819816897, -0.7255428207032856, 0.6498459004171097, 0.07709318125468112, -0.21294262711122242, -0.42342554239175445, 0.8771675375965828, -0.05301920697004664, -0.5481072468001156, -0.3126502801523175, 0.7739620223358135]
+ }
+ props {
+ position: [0.9994956613474709, 0.0004992247479833129, 0.10136284191330483]
+ quaternion: [0.49647018443879004, 0.4996812760735259, 0.5019163300043502, 0.5019123190157927]
+ velocity: [-0.011991426132309257, 0.011869882254764949, 0.7667661145188961]
+ angular_velocity: [-0.10671898398231862, -0.000248768550912366, -0.10727570209153475]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.2560240947758668, -0.0018703715147317037, 0.08693244357877518]
+ quaternion: [0.7048426735504112, -0.020406851773145745, -0.02049208075388735, 0.7087738994696361]
+ joints: [0.03386003707346549, 0.007663530374561892, -1.0563940167572543, 1.5382547640444117, -0.37115271046897663, 0.33790031082317057, -0.6046389229386322, -0.03362422772543456, -0.00768649002776556, -1.0545913847572377, 1.5337126896920363, 0.3712326585555403, 0.3388325518451962, -0.6043546644309572, -0.0019739939135093613, 0.001257956055686199, 0.04556936050378683, -0.002347683691000412, 0.0012503289793459355, 0.0332411277355018, -0.0018208450144775823, 0.000986060981243297, 0.06829935311049291, -0.0036838596777776186, -0.0014101376904499558, 0.016860219833969015, -0.002164894916681967, 0.0006029116174168797, 0.07504523205031079, -0.0007472478210414381, 8.871370299082384e-05, 0.11621158925536389, -0.007945308628964467, 0.006179599789535761, 0.09380768973885034, -0.05120155051643746, 0.17418786251683416, 1.3802879451748555, 1.562229010264374, 0.0031403906028622825, 0.0003854009235979718, 0.7816815929785135, 0.00014228669168737072, 0.7813288880353004, 0.007836223506801502, -0.0050951626337969475, -0.09319756250769529, 0.05095714413844841, 0.17700846450500915, 1.3806993774089737, -1.5622350572258366, -0.0031181932618382507, 0.00038576637477516333, 0.7816788209752066, -0.00014202891891248993, 0.7813267728081049]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.45260388316797434, 0.40430461677221524, 0.10452659994205875, 0.4591349272260928, 0.3984332419407959, 0.09996060621892619, -0.3121853738880218, -0.6107194491988526, 0.21459701788635455, 0.31187051423969303, -0.6096074006887925, 0.21397126281776996]
+ velocity: [-0.0019950042200670244, 3.184589985352281e-05, -0.004990173590386357]
+ angular_velocity: [-0.07977671158580836, -0.0007017914494017582, -0.0003382842162919016]
+ joints_velocity: [0.3652364387034681, 0.09747785230721197, -0.6438838928425603, -3.9309142857918955, 0.13707747078826563, -0.2297414276128733, -0.04881687461746725, -0.37356803563850616, -0.16020746117216747, -0.6069427508199452, -4.299618506558858, -0.14761339822411293, -0.2479097376364327, -0.05218596455137056, 0.026369220084524515, 0.0014922286374551951, 0.3173077006905387, 0.018999214992559244, 0.0038259430971992465, 0.03341032075458494, 0.00476958960429413, 0.004537197119830248, -0.3332575403371123, -0.04957084398369801, 0.022116299057731038, 0.5088703783553401, -0.027089901586757534, 0.005522930997614167, 0.2519625010277204, -0.013981158084018402, 0.004556096256227826, 0.14544848852612202, 0.07919264137088781, -0.020059068584073324, -3.723487690898005, 1.8068209057392486, 1.843129097469116, -0.025030553464704758, 0.0034258075104422, -0.01905235704493633, 0.0015931222882354969, 0.0005466993431291613, -0.0011826037354277054, -0.0007982531528021434, -0.0673379711177572, 0.024965504407983286, 3.721134530846577, -1.7912209203030116, 1.8730257346139503, -0.007131562154055242, -0.00346671800337376, 0.019311732839868503, 0.0018024713566275451, 0.0005477335673379927, 0.0011800620608441378, -0.0008064363527864845]
+ appendages: [-0.45260388316797434, 0.40430461677221524, 0.10452659994205875, 0.4591349272260928, 0.3984332419407959, 0.09996060621892619, -0.3121853738880218, -0.6107194491988526, 0.21459701788635455, 0.31187051423969303, -0.6096074006887925, 0.21397126281776996, 0.002269421891503984, 0.5146457954385789, 0.07056986435676293]
+ body_positions: [-0.2560240947758668, -0.0018703715147317037, 0.08693244357877518, -0.16779725618322028, 0.10056102124242446, 0.14159963848345572, -0.005117267914833551, 0.17843010486574612, 0.5041628471211662, 0.33844918792604695, 0.31332075162585066, 0.33577812665464524, 0.4316020280737776, 0.316742872120093, 0.39185457271798135, -0.2560240947758668, -0.0018703715147317037, 0.08693244357877518, -0.16666356753068678, -0.10330982649831631, 0.14160785498158734, -0.0025015321346305575, -0.17953097230329404, 0.5038530980032339, 0.34299338850146405, -0.310719487425272, 0.3364922795188563, 0.4360849565626066, -0.31306041413412083, 0.3927258136192622, -0.2560240947758668, -0.0018703715147317037, 0.08693244357877518, -0.36913093477485337, -0.0017135044170932749, 0.07748934060142815, -0.4826468828916977, -0.0013240576833369187, 0.08080472487965146, -0.5957210981673924, -0.001242097832057136, 0.09562315117075618, -0.6845741578237263, -0.0041015051004690915, 0.11584001936729503, -0.7738936984681265, -0.0024790528844625386, 0.12762954264336576, -0.5957210981673924, -0.001242097832057136, 0.09562315117075618, -0.6695138080509674, 0.18328803192894452, 0.1112220912477744, -0.6621170668482199, 0.455016538597933, 0.1636716244394163, -0.7297299760677168, 0.45937087915677444, 0.3320903096295511, -0.7635362816978486, 0.46154803984175535, 0.4162992811275684, -0.7765457068452343, 0.46242646434066986, 0.44898917282353884, -0.7720019056308054, 0.4573370353196007, 0.4130434168275488, -0.5957210981673924, -0.001242097832057136, 0.09562315117075618, -0.6697541729333899, -0.18554620050768214, 0.1126866013120635, -0.6631725779477136, -0.45674206713041543, 0.1679272673398115, -0.7308903994684962, -0.4590965265416534, 0.3363436337219428, -0.764749161018105, -0.4602737510594101, 0.42055144582106774, -0.7777797067402005, -0.4607640062644533, 0.45324105149899396, -0.7732000412389504, -0.4560719108765266, 0.4172458199241415]
+ body_quaternions: [0.7048426735504112, -0.020406851773145745, -0.02049208075388735, 0.7087738994696361, 0.4675510767279952, -0.30420656755758774, -0.43040463122334466, 0.7096521741859316, 0.5475220544892311, 0.10660790281479293, 0.18431571466173502, 0.8092478435152342, 0.6137241360844101, -0.1234465528700522, -0.3830757509252244, 0.6792323626160796, 0.5491364153858118, -0.3005753880897024, -0.5679341152529402, 0.5343729728418214, 0.7048426735504112, -0.020406851773145745, -0.02049208075388735, 0.7087738994696361, 0.7073545787885803, -0.4280606072369331, -0.3062067659970148, 0.47185912395165946, 0.806383396433437, 0.182571086313513, 0.10691309807583585, 0.5522528459488505, 0.6755504585917973, -0.38322495493972936, -0.1258239570592103, 0.6172021902406186, 0.5308942046199554, -0.5669048700558941, -0.303802090052408, 0.5517920821175702, 0.7058461253715639, -0.0047521242435814605, -0.003931592130184531, 0.7083383423170359, 0.7066680876015369, 0.006555772168328767, 0.008263715829150288, 0.7074665693992604, 0.7066836551144652, 0.030345780729787143, 0.0327465710685319, 0.7061196833888805, 0.7077213772781614, 0.03674975338640174, 0.03824350896814494, 0.7044979359738, 0.7066041735451051, 0.06304798446320849, 0.0648698064895123, 0.7018029650762238, 0.7020123282496592, 0.10393606887488184, 0.1055558451574364, 0.6965873585742336, 0.7093783021720459, 0.028042233038331275, 0.03505886651246383, 0.7033967112892914, -0.0907291067077882, 0.976291375846323, -0.031708076670974285, 0.19395354216391225, -0.6915203071849202, 0.6951088304048317, 0.09902835452963855, 0.16975501065727952, 0.6411446337919496, -0.7152845746759443, -0.27738152211306505, 0.01900071000461578, 0.6412519161265199, -0.7155956796347399, -0.27625418358636183, 0.020060643140088045, 0.8655120953972923, -0.41733680567238157, -0.24777917741632355, 0.12379128641003517, 0.7664311851136578, -0.28506274189502917, -0.5755467173532717, -0.008273319311834228, 0.7039924186756222, 0.03227968875197962, 0.030834045922791053, 
0.7088031868955965, 0.1939112386339727, -0.031264483900944676, 0.9758069835942897, -0.09602965347464207, 0.16941953541652785, 0.09937835499444436, 0.6912176566414883, -0.6954416688155061, 0.01829902795398369, -0.27874839417643266, -0.7114421785427413, 0.6448368048664389, 0.019357919702823672, -0.27763512008121066, -0.7117515236028833, 0.6449447880952642, 0.12366730020735331, -0.24932386130724343, -0.412376884873739, 0.8674614203894613, -0.010255056527450039, -0.5761518038430445, -0.2800940310336361, 0.7677833460836156]
+ }
+ props {
+ position: [0.9988960900408556, 0.0010927188607215605, 0.12621239763924963]
+ quaternion: [0.4991489590238457, 0.4970101730312855, 0.5019106590007048, 0.5019132345239578]
+ velocity: [-0.011991426132309257, 0.011869882254764949, 0.276266114518896]
+ angular_velocity: [-0.10671855794228392, -0.000248768550912366, -0.10727612591907668]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.25569992908432526, -0.0018642428763650377, 0.08709663471867946]
+ quaternion: [0.7048993068415158, -0.01937725285581231, -0.019462844066693578, 0.7087754841887155]
+ joints: [0.0631515885639572, 0.004169693140012057, -1.0892642877261176, 1.3828318956815084, -0.3630807642279871, 0.32688179108044374, -0.6071225055936477, -0.06339639081574387, -0.006222138603905286, -1.0867169053676793, 1.3653977030001396, 0.36266568318596376, 0.3268768045908207, -0.6070348871477607, -0.0011351265650508921, 0.001223518771003787, 0.040507675573732756, -0.001299270959793892, 0.0013128937955830743, 0.01885660590221743, -0.0009397291030990414, 0.0010859621062441292, 0.04693916530354042, -0.004539166543891022, -0.00044672336609063084, 0.0648898742957408, -0.0027859696767987964, 0.0006020736105026263, 0.1019243777045522, -0.0010763067744023632, 0.00017590715785518364, 0.12850794938306703, 0.02027321058002578, 0.004822434046960115, -0.04194616896017524, 0.024987359251621405, 0.2561279573986409, 1.3855728736355533, 1.562729966441628, 0.0011243037972401403, -0.0003430914539431152, 0.7814624866881186, 5.2123373906859876e-05, 0.7811386410396247, -0.020582293916294583, -0.0035372645354618936, 0.04101151572833599, -0.024564034064283383, 0.25903470283885976, 1.3865039422733008, -1.5627319433778553, -0.0011026104653882828, -0.0003063516490554674, 0.7814666703875564, -5.207144805865627e-05, 0.78113958336502]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.4460950914857898, 0.4657020434546673, 0.10044888525943281, 0.45155579925254774, 0.462640784881067, 0.09568993440523907, -0.33524461182439436, -0.6211768832273884, 0.2962585575538008, 0.3335816229156577, -0.6190469429811182, 0.2909462625899752]
+ velocity: [-0.009464592316579517, 1.7055665863492394e-05, -0.0024259998429329043]
+ angular_velocity: [-0.07014906819520904, 0.0001507089144054196, -0.002007898775813474]
+ joints_velocity: [0.6854229611539538, -0.22582885331349276, -0.5085328981509729, -2.115892247997066, 0.16156505807032967, -0.1958707628449476, -0.047454277256122654, -0.7034701750407242, 0.20603196935768955, -0.5347231143420397, -2.2353609580588554, -0.16975498091598876, -0.21342840113643347, -0.05155005294784601, 0.014520994768745764, -0.002245068167681625, -0.28194385191499743, 0.018951377298281183, -0.0007471280236298408, -0.4093733414529402, 0.017068120275740294, 5.5476985076727975e-05, -0.3960801431471277, 0.011302327471636796, 0.01497191608601326, 1.0998310178491584, 0.0007859466981119841, -0.0029290516084900468, 0.6377022412961912, -0.00039367138323310167, 0.00010286465399677658, 0.2832379374723565, 0.945701794262913, -0.021609468359945178, -1.4714854169766962, 0.9640572043639581, 1.2167430163843511, 0.33947637874223136, 0.011780439423220221, -0.04600990610056051, -0.014500662668075991, -0.004798148714332736, -0.002001188186454031, -0.004247978463962675, -0.9383260887288267, 0.02441775793856877, 1.4384816584873623, -0.9511127679856959, 1.2029434647046828, 0.33815050238198596, -0.01166009292625669, 0.04577903313119494, -0.01378588913566588, -0.004639962094004206, 0.0019934244951556566, -0.004170318476796304]
+ appendages: [-0.4460950914857898, 0.4657020434546673, 0.10044888525943281, 0.45155579925254774, 0.462640784881067, 0.09568993440523907, -0.33524461182439436, -0.6211768832273884, 0.2962585575538008, 0.3335816229156577, -0.6190469429811182, 0.2909462625899752, 0.0013759521814354708, 0.5147645746859154, 0.06684713177850474]
+ body_positions: [-0.25569992908432526, -0.0018642428763650377, 0.08709663471867946, -0.16730625695075546, 0.10056056535117719, 0.14150605352888335, -0.016456460526455108, 0.1779091665314929, 0.5092586685287035, 0.3445998595008844, 0.3350234646182242, 0.4115845094381812, 0.42595897838294694, 0.33682385947050314, 0.4837730313268466, -0.25569992908432526, -0.0018642428763650377, 0.08709663471867946, -0.16618842017732166, -0.1033103699992213, 0.14151207861041581, -0.013943794899921952, -0.1800556397227964, 0.5088160272336391, 0.35010203270125373, -0.3337824945689379, 0.41702550618543927, 0.43040956104293976, -0.33447525898431557, 0.49040091970028177, -0.25569992908432526, -0.0018642428763650377, 0.08709663471867946, -0.36878630231426607, -0.0017941765742777419, 0.07741049327353795, -0.48234207255267325, -0.0016146914414572305, 0.07884880797693415, -0.5958962549902604, -0.0018493339968680255, 0.08937321908533458, -0.6845342220899222, -0.004855868712909581, 0.11049277678277243, -0.77335813522888, -0.0033254842622178374, 0.1255783967339999, -0.5958962549902604, -0.0018493339968680255, 0.08937321908533458, -0.6758933873265804, 0.18026072916292757, 0.10265642152302207, -0.7253657563963872, 0.44712521049473075, 0.15722637617875332, -0.8109553992172212, 0.4333263051586591, 0.316723468783378, -0.8537500320377324, 0.4264268828954032, 0.3964716636466532, -0.8703215990057782, 0.42378382941988607, 0.4274077789152051, -0.8608401280308875, 0.42081993262046363, 0.39219467318346535, -0.5958962549902604, -0.0018493339968680255, 0.08937321908533458, -0.675197710271082, -0.18414466339163188, 0.1041972476976385, -0.7237618497059807, -0.4505302434377207, 0.1618365897606474, -0.8092830696932304, -0.43502632710577116, 0.3212137173739419, -0.8520434912477137, -0.4272744031014312, 0.40090193000588525, -0.8686015591124719, -0.4242990652844558, 0.4318150913586317, -0.8591522290274858, -0.421740196239765, 0.3965616072539776]
+ body_quaternions: [0.7048993068415158, -0.01937725285581231, -0.019462844066693578, 0.7087754841887155, 0.4533162507427195, -0.30346855298825604, -0.44609325523303717, 0.7095153428893349, 0.5427096509889634, 0.055272240443640375, 0.10876203090241854, 0.8310126562175569, 0.5688809327924056, -0.15631870075513477, -0.4630922629455469, 0.6614261138626458, 0.49614369777612555, -0.31922168325093986, -0.6396326017250292, 0.4927566163025509, 0.7048993068415158, -0.01937725285581231, -0.019462844066693578, 0.7087754841887155, 0.707899789312278, -0.44307926599488, -0.3060034720590829, 0.45707825087708664, 0.8287742024080129, 0.10283321003397333, 0.05094597869083459, 0.5476889259354291, 0.6564501447778921, -0.4676664167347415, -0.1624723254727868, 0.5691784197727779, 0.4866689973958854, -0.6424888768641152, -0.32516120913677726, 0.4946023838926384, 0.7055695779761122, -0.005511170512254434, -0.004692359920002123, 0.7086036828810426, 0.7060576274781435, 0.0006838774772373696, 0.0024441937155121984, 0.7081498322449935, 0.7061873118735406, 0.01687544456015664, 0.019437966441395026, 0.7075569697065693, 0.7068720681192221, 0.03994038389219727, 0.0422083960250504, 0.7049504212016547, 0.7049008527093986, 0.07567562003797146, 0.07828261585118922, 0.7008992940764931, 0.6989632173227789, 0.12070121177529647, 0.12320425276528518, 0.6940478012401686, 0.6989304424757922, 0.015348315465699631, 0.020951243278772846, 0.7147179242176822, -0.11832398519894324, 0.9614860644918305, 0.06341463317280045, 0.239838626193137, -0.705132192668028, 0.6642624545678667, 0.20197664669583573, 0.14404657755100497, 0.6273020463452265, -0.6787725140434453, -0.3801290801798868, 0.03552322102267247, 0.6271654913889255, -0.6790936946170586, -0.3797535956411918, 0.03581070984813245, 0.8385393117975031, -0.3890457842310198, -0.33749283846821526, 0.17774640453440435, 0.7320069744972056, -0.21997113069006938, -0.643492367270073, 0.04118330020211258, 0.7134649291898224, 0.017912977338250608, 0.01834877401000351, 
0.7002215667568165, 0.2400853733689033, 0.060435957585594086, 0.9609745570688857, -0.12342774887068805, 0.14604582434108673, 0.1999102515317958, 0.6602791634666231, -0.7090401503559887, 0.03254707147194902, -0.3791613138307707, -0.6754543986489769, 0.6316159763483447, 0.032837258017103854, -0.3787939090023924, -0.6757600672146964, 0.6314944342958175, 0.17463225516588488, -0.33773760001970743, -0.3843129085229536, 0.8412731288536641, 0.03628835715327525, -0.6431694526863097, -0.2163729133630063, 0.7336204554325857]
+ }
+ props {
+ position: [0.9990628057030198, 0.0009313441928349641, 0.12764391095654046]
+ quaternion: [0.4987639746022105, 0.4974010120003192, 0.5019174014145118, 0.5019020353190128]
+ velocity: [0.04653603614214268, -0.0457816699847161, -0.13000932762891504]
+ angular_velocity: [0.3584067352540253, -0.0001108899265487697, 0.3577181593170936]
+ }
+}
+timesteps {
+ walkers {
+ position: [-0.2589338851715894, -0.0018784374308002223, 0.08536131971428179]
+ quaternion: [0.7044755346484229, -0.031812867644570725, -0.03195269894136319, 0.7082945627099014]
+ joints: [0.08875126457799538, -0.012967733893576555, -1.0900788679801792, 1.3608001879877907, -0.35766403187773455, 0.32120608143600765, -0.6088391272954923, -0.09036558880162303, 0.01051880744395801, -1.0923190258396336, 1.3474358490747516, 0.35710967869780874, 0.3206252549890462, -0.6089184473688815, -0.0003826348852714126, 0.0009718033720418631, 0.031211773456739127, -0.00040037452477553563, 0.001098387655969389, 0.00655633473450552, -0.0001712845414809236, 0.0009170505903830211, 0.03730145262262089, -0.0032738058706338314, -7.1757107555202535e-06, 0.10977396436374744, -0.002253567930270911, 0.00041750077510282774, 0.12866016605168873, -0.0008658691307699896, 0.00014580097489509932, 0.14041659615334512, 0.05179309656500675, 0.0016797105091955278, -0.06969973820585103, 0.042095673994173646, 0.2970017147014684, 1.4034240516277323, 1.5631441684972884, -0.0006080653942662456, -0.00038861601375978134, 0.7813884619547344, -2.6488915953525288e-05, 0.781020232101423, -0.05154325171079399, -0.00045398297083643177, 0.06783141960145339, -0.041126428191386705, 0.29883687811227033, 1.4039013604892698, -1.563140773599554, 0.0006104073323316494, -0.0003402987300704897, 0.7813959476931572, 2.5775488898650408e-05, 0.7810234920378067]
+ center_of_mass: [0.0, 0.0, 0.0]
+ end_effectors: [-0.43825301369038344, 0.4933358038046371, 0.09733222045887141, 0.441773178847443, 0.4930111864909349, 0.09346576387449497, -0.34581472766396903, -0.6169586961298018, 0.3065202972595384, 0.34411684274246374, -0.6168214608477081, 0.29998185458902094]
+ velocity: [-0.1100940425917514, -0.0005348574343312351, -0.0622043257733096]
+ angular_velocity: [-1.2492347783633728, 0.0006848427735458744, -0.0007593700258209132]
+ joints_velocity: [0.26778224185231436, -0.3437469238772742, 0.3676761347944968, 0.9670530180772898, 0.04813829836188294, -0.030584948669908452, -0.01945915792840022, -0.29815509418627933, 0.3399150937665342, 0.2259136202191499, 1.1846062128406696, -0.046314743190786836, -0.035434469580499006, -0.021732253244577372, 0.016652364816128258, -0.007399082537357797, -0.08765082597986062, 0.017419780069987664, -0.007564650481632352, -0.07828732131968298, 0.013019717253623631, -0.006613408384884584, -0.015613698128663325, 0.03054751882002135, 0.0038161945260472008, 0.6615367044073728, 0.016404131276386576, -0.003704663497979105, 0.4073469498781107, 0.006680492887916364, -0.0008425713967623944, 0.18332729526358152, 0.24956993523865406, -0.096522681072924, 0.05423743903665995, -0.059565393942455275, 0.4518917245750971, 0.23305306046185773, 0.004684372046913005, -0.022464160733721726, 0.008172257040350555, 0.0008734282489324963, -0.001053305469381935, -0.0008093509260699144, -0.2402123394754958, 0.0911746989386925, -0.0601219416369371, 0.06638496681891413, 0.42678340289779504, 0.22201662202789885, -0.004601307570776698, 0.022029951573484193, 0.008032564627832237, 0.0008634505611443933, 0.001034264142888148, -0.000792524080968282]
+ appendages: [-0.43825301369038344, 0.4933358038046371, 0.09733222045887141, 0.441773178847443, 0.4930111864909349, 0.09346576387449497, -0.34581472766396903, -0.6169586961298018, 0.3065202972595384, 0.34411684274246374, -0.6168214608477081, 0.29998185458902094, 0.00010321201584700578, 0.5140949282118228, 0.06453662266480031]
+ body_positions: [-0.2589338851715894, -0.0018784374308002223, 0.08536131971428179, -0.17250639730994188, 0.1005297843852689, 0.14287230053851713, -0.03439020973687312, 0.17646587566453642, 0.5158844743298263, 0.32648680385413764, 0.34542357006478397, 0.4396749682718708, 0.40361175052196685, 0.34780736884710284, 0.5163547599004996, -0.2589338851715894, -0.0018784374308002223, 0.08536131971428179, -0.17140459594100899, -0.10334123813116401, 0.14288167324064383, -0.033727711444801745, -0.17847871766820542, 0.5162178323107911, 0.32976307439550934, -0.344500068923587, 0.44623090649216546, 0.4056754588898466, -0.345875917634121, 0.5241356384413204, -0.2589338851715894, -0.0018784374308002223, 0.08536131971428179, -0.37147643707145034, -0.0018824632326525277, 0.07064642755213985, -0.4849305396142878, -0.0018875160159372718, 0.06562840909805498, -0.598932168640416, -0.0024053270947531568, 0.06858325028701501, -0.6879945281494322, -0.005738232011731522, 0.08778095961380826, -0.7767453559817022, -0.004572860550552978, 0.10332245309548896, -0.598932168640416, -0.0024053270947531568, 0.06858325028701501, -0.6857817550624453, 0.17682562620030617, 0.07715497545565708, -0.7607406749810872, 0.4371914634773306, 0.13401302460144357, -0.856626124757387, 0.40866183290804603, 0.28549688511726506, -0.9045686383696472, 0.394397080486154, 0.3612384815926522, -0.9231791611336606, 0.3888639482377536, 0.3905928455157872, -0.9107707480805859, 0.38840433716050454, 0.356176776092886, -0.598932168640416, -0.0024053270947531568, 0.06858325028701501, -0.6842124733140329, -0.1823222581133838, 0.0784436417960984, -0.7566563319124348, -0.44282533062744645, 0.13787497369357304, -0.8526027397457804, -0.4136684121140389, 0.289200715795437, -0.9005757322522473, -0.39909001710226394, 0.36486325341224657, -0.9191971869480815, -0.39343420689247893, 0.3941872883877046, -0.906838699455358, -0.3932044682600658, 0.3597509587573132]
+ body_quaternions: [0.7044755346484229, -0.031812867644570725, -0.03195269894136319, 0.7082945627099014, 0.4394935814814895, -0.30449210721518827, -0.46161053385288525, 0.7078457907738674, 0.533184851218239, 0.0397990697272791, 0.08648907324357237, 0.8406245230123232, 0.548412193247333, -0.16285438740640804, -0.48761929614007826, 0.6595073440318828, 0.4743829110490296, -0.319747304765052, -0.6628794803650989, 0.48302516426272246, 0.7044755346484229, -0.031812867644570725, -0.03195269894136319, 0.7082945627099014, 0.7056200711029307, -0.460392631309297, -0.30801843234284504, 0.44188639448375566, 0.8386844249620079, 0.08043317122524117, 0.03497239289898182, 0.5375089506489322, 0.6547058327202665, -0.49255991898704377, -0.16919146786595418, 0.5478314029068645, 0.4769371014274121, -0.6661726521499355, -0.3256376438775012, 0.471916437199902, 0.7050423966765198, -0.021151056587525734, -0.020561789242236797, 0.7085513845283333, 0.7052622232072603, -0.019224131102913152, -0.017856383955287967, 0.708460852025101, 0.7055728904649105, -0.0063898466166814015, -0.004320319250257996, 0.7085953718045177, 0.7060175704754205, 0.03240028396138757, 0.03448294259268927, 0.7066118725646744, 0.7032731904386288, 0.07758497527234876, 0.07996958586375463, 0.7021340730651383, 0.6964013403369623, 0.12666422663474644, 0.12909142030983267, 0.6944902822000423, 0.6869914367711893, -0.007109832651858841, -0.0035764421452735154, 0.7266219272388503, -0.13836805480694886, 0.9461384984188699, 0.10302096594882515, 0.2739761007708494, -0.7164341252635894, 0.6332818653168335, 0.2555397854714512, 0.14274327045374968, 0.6242253136429275, -0.6485436667259392, -0.4346506592752258, 0.028507447249345168, 0.6241079486610748, -0.6485327701550642, -0.43485331180300224, 0.02823316608219809, 0.8240592406129768, -0.36198322268759936, -0.3913328477106524, 0.19171102408504162, 0.7079091116734603, -0.17656678550166527, -0.6819984884609105, 0.05066479649948982, 0.7235971879066638, -0.006119730651298536, 
-0.004647793628640624, 0.6901797277265607, 0.2745034449437161, 0.09719292038087397, 0.9461166303446573, -0.14164997962280335, 0.14684817041096682, 0.2514641523260966, 0.6309750024946268, -0.7190771454918972, 0.023942030322665897, -0.43141285648972233, -0.6473952564779659, 0.6278448122977809, 0.023676970089058436, -0.43161445866869286, -0.6473703768069542, 0.6277419496802452, 0.18626616688310293, -0.3900724995678919, -0.35952138819120927, 0.8269780720119483, 0.04417919339878007, -0.6809515449311216, -0.17610465218973015, 0.7094648291520067]
+ }
+ props {
+ position: [0.9993826899103103, 0.0006124708016742434, 0.12694550340154664]
+ quaternion: [0.49752659686213496, 0.49864018187299153, 0.5019145365658443, 0.5019034293769534]
+ velocity: [-0.005000523197381878, 0.004858290674399023, 0.015948639890104466]
+ angular_velocity: [-0.038992755454138814, -3.6388304643072614e-05, -0.03846776459978903]
+ }
+}
diff --git a/dm_control/locomotion/mocap/test_trajectories.h5 b/dm_control/locomotion/mocap/test_trajectories.h5
new file mode 100644
index 00000000..2e018461
Binary files /dev/null and b/dm_control/locomotion/mocap/test_trajectories.h5 differ
diff --git a/dm_control/locomotion/mocap/trajectory.py b/dm_control/locomotion/mocap/trajectory.py
new file mode 100644
index 00000000..6d070acd
--- /dev/null
+++ b/dm_control/locomotion/mocap/trajectory.py
@@ -0,0 +1,277 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Represents a motion-captured trajectory."""
+
+import collections
+import copy
+
+from dm_control.locomotion.mocap import mocap_pb2
+from dm_control.locomotion.mocap import props as mocap_props
+from dm_control.locomotion.mocap import walkers as mocap_walkers
+import numpy as np
+
+STEP_TIME_TOLERANCE = 1e-4
+
+_REPEATED_POSITION_FIELDS = ('end_effectors', 'appendages', 'body_positions')
+_REPEATED_QUATERNION_FIELDS = ('body_quaternions',)
+
+
+def _zero_out_velocities(timestep_proto):
+ out_proto = copy.deepcopy(timestep_proto)
+ for walker in out_proto.walkers:
+ walker.velocity[:] = np.zeros_like(walker.velocity)
+ walker.angular_velocity[:] = np.zeros_like(walker.angular_velocity)
+ walker.joints_velocity[:] = np.zeros_like(walker.joints_velocity)
+ for prop in out_proto.props:
+ prop.velocity[:] = np.zeros_like(prop.velocity)
+ prop.angular_velocity[:] = np.zeros_like(prop.angular_velocity)
+ return out_proto
+
+
+class Trajectory:
+ """Represents a motion-captured trajectory."""
+
+ def __init__(self, proto, start_time=None, end_time=None, start_step=None,
+ end_step=None, zero_out_velocities=True):
+ """A wrapper around a mocap trajectory proto.
+
+ Args:
+ proto: proto representing the mocap trajectory.
+ start_time: Start time of the mocap trajectory if only a subset of the
+ underlying clip is desired. Defaults to the start of the full clip.
+ Cannot be used when start_step is provided.
+ end_time: End time of the mocap trajectory if only a subset of the
+ underlying clip is desired. Defaults to the end of the full clip.
+ Cannot be used when end_step is provided.
+ start_step: Like start_time but using time indices. Defaults to the start
+ of the full clip. Cannot be used when start_time is provided.
+ end_step: Like end_time but using time indices. Defaults to the end
+ of the full clip. Cannot be used when end_time is provided.
+ zero_out_velocities: Whether to zero out the velocities in the last time
+ step of the requested trajectory. Depending on the use-case it may be
+ beneficial to use a stable end pose.
+ """
+ self._proto = proto
+ self._zero_out_velocities = zero_out_velocities
+
+ if (start_time and start_step) or (end_time and end_step):
+ raise ValueError(('Please specify either start and end times '
+ 'or start and end steps but not both.'))
+ if start_step:
+ start_time = start_step * self._proto.dt
+ if end_step:
+ end_time = end_step * self._proto.dt
+ self._set_start_time(start_time or 0.)
+ self._set_end_time(end_time or (len(self._proto.timesteps)*self._proto.dt))
+ self._walkers_info = tuple(mocap_walkers.WalkerInfo(walker_proto)
+ for walker_proto in self._proto.walkers)
+ self._dict = None
+
+ def as_dict(self):
+ """Return trajectory as dictionary."""
+ if self._dict is None:
+ self._dict = dict()
+
+ if self._proto.timesteps:
+ initial_timestep = self._proto.timesteps[0]
+
+ num_walkers = len(initial_timestep.walkers)
+ for i in range(num_walkers):
+ key_prefix = 'walker_{:d}/'.format(
+ i) if num_walkers > 1 else 'walker/'
+ for field in mocap_pb2.WalkerPose.DESCRIPTOR.fields:
+ field_name = field.name
+
+ def walker_field(timestep, i=i, field_name=field_name):
+ values = getattr(timestep.walkers[i], field_name)
+ if field_name in _REPEATED_POSITION_FIELDS:
+ values = np.reshape(values, (-1, 3))
+ elif field_name in _REPEATED_QUATERNION_FIELDS:
+ values = np.reshape(values, (-1, 4))
+ return np.array(values)
+
+ self._dict[key_prefix + field_name] = walker_field
+
+ num_props = len(initial_timestep.props)
+ for i in range(num_props):
+ key_prefix = 'prop_{:d}/'.format(i) if num_props > 1 else 'prop/'
+ for field in mocap_pb2.PropPose.DESCRIPTOR.fields:
+ field_name = field.name
+
+ def prop_field(timestep, i=i, field_name=field_name):
+ return np.array(getattr(timestep.props[i], field_name))
+
+ self._dict[key_prefix + field_name] = prop_field
+
+ self._create_all_items(self._dict)
+ for k in self._dict:
+ # Make the trajectory arrays read-only by default.
+ self._dict[k].flags.writeable = False # pytype: disable=attribute-error
+
+ return {k: v[self._start_step:self._end_step]
+ for k, v in self._dict.items()}
+
+ def _create_single_item(self, get_field_in_timestep):
+ if not self._proto.timesteps:
+ return np.empty((0,))
+ for i, timestep in enumerate(self._proto.timesteps):
+ values = get_field_in_timestep(timestep)
+ if i == 0:
+ array = np.empty((len(self._proto.timesteps),) + values.shape)
+ array[i, :] = values
+ return array
+
+ def _create_all_items(self, dictionary):
+ for key, value in dictionary.items():
+ if callable(value):
+ dictionary[key] = self._create_single_item(value)
+ return dictionary
+
+ def _get_quantized_time(self, time):
+ if time == float('inf'):
+ return len(self._proto.timesteps) - 1
+ else:
+ divided_time = time / self._proto.dt
+ quantized_time = int(np.round(divided_time))
+ if np.abs(quantized_time - divided_time) > STEP_TIME_TOLERANCE:
+ raise ValueError('`time` should be a multiple of dt = {}: got {}'
+ .format(self._proto.dt, time))
+ return quantized_time
+
+ def _get_step_id(self, time):
+ quantized_time = self._get_quantized_time(time)
+ return np.clip(quantized_time + self._start_step,
+ self._start_step, self._end_step - 1)
+
+ def get_modified_trajectory(self, proto_modifier, random_state=None):
+ modified_proto = copy.deepcopy(self._proto)
+ if isinstance(proto_modifier, collections.abc.Iterable):
+ for proto_mod in proto_modifier:
+ proto_mod(modified_proto, random_state=random_state)
+ else:
+ proto_modifier(modified_proto, random_state=random_state)
+ return type(self)(modified_proto, self.start_time, self.end_time)
+
+ @property
+ def identifier(self):
+ return self._proto.identifier
+
+ @property
+ def start_time(self):
+ return self._start_step * self._proto.dt
+
+ def _set_start_time(self, new_value):
+ self._start_step = np.clip(self._get_quantized_time(new_value),
+ 0, len(self._proto.timesteps) - 1)
+
+ @start_time.setter
+ def start_time(self, new_value):
+ self._set_start_time(new_value)
+
+ @property
+ def start_step(self):
+ return self._start_step
+
+ @start_step.setter
+ def start_step(self, new_value):
+ self._start_step = np.clip(int(new_value), 0,
+ len(self._proto.timesteps) - 1)
+
+ @property
+ def end_step(self):
+ return self._end_step
+
+ @end_step.setter
+ def end_step(self, new_value):
+ self._end_step = np.clip(int(new_value), 0,
+ len(self._proto.timesteps) - 1)
+
+ @property
+ def end_time(self):
+ return (self._end_step - 1) * self._proto.dt
+
+ @property
+ def clip_end_time(self):
+ """Time of the last step of the full clip."""
+ return (len(self._proto.timesteps) - 1) * self._proto.dt
+
+ def _set_end_time(self, new_value):
+ self._end_step = 1 + np.clip(self._get_quantized_time(new_value),
+ 0, len(self._proto.timesteps) - 1)
+ if self._zero_out_velocities:
+ self._last_timestep = _zero_out_velocities(
+ self._proto.timesteps[self._end_step - 1])
+ else:
+ self._last_timestep = self._proto.timesteps[self._end_step - 1]
+
+ @end_time.setter
+ def end_time(self, new_value):
+ self._set_end_time(new_value)
+
+ @property
+ def duration(self):
+ return self.end_time - self.start_time
+
+ @property
+ def num_steps(self):
+ return self._end_step - self._start_step
+
+ @property
+ def dt(self):
+ return self._proto.dt
+
+ def configure_walkers(self, walkers):
+ try:
+ walkers = iter(walkers)
+ except TypeError:
+ walkers = iter((walkers,))
+ for walker, walker_info in zip(walkers, self._walkers_info):
+ walker_info.rescale_walker(walker)
+ walker_info.add_marker_sites(walker)
+
+ def create_props(self,
+ proto_modifier=None,
+ priority_friction=False,
+ prop_factory=None):
+ proto = self._proto
+ prop_factory = prop_factory or mocap_props.Prop
+ if proto_modifier is not None:
+ proto = copy.copy(proto)
+ proto_modifier(proto)
+ return tuple(
+ prop_factory(prop_proto, priority_friction=priority_friction)
+ for prop_proto in proto.props)
+
+ def get_timestep_data(self, time):
+ step_id = self._get_step_id(time)
+ if step_id == self._end_step - 1:
+ return self._last_timestep
+ else:
+ return self._proto.timesteps[step_id]
+
+ def set_walker_poses(self, physics, walkers):
+ timestep = self._proto.timesteps[self._get_step_id(physics.time())]
+ for walker, walker_timestep in zip(walkers, timestep.walkers):
+ walker.set_pose(physics,
+ position=walker_timestep.position,
+ quaternion=walker_timestep.quaternion)
+ physics.bind(walker.mocap_joints).qpos = walker_timestep.joints
+
+ def set_prop_poses(self, physics, props):
+ timestep = self._proto.timesteps[self._get_step_id(physics.time())]
+ for prop, prop_timestep in zip(props, timestep.props):
+ prop.set_pose(physics,
+ position=prop_timestep.position,
+ quaternion=prop_timestep.quaternion)
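
The time-to-step conversion in `Trajectory._get_quantized_time` above can be tried out in isolation. The following is a minimal standalone sketch and is not part of the patch: `quantize_time` is a hypothetical helper name, and its `dt` and `num_steps` arguments stand in for `proto.dt` and `len(proto.timesteps)`.

```python
import numpy as np

STEP_TIME_TOLERANCE = 1e-4  # Same tolerance value as in trajectory.py.


def quantize_time(time, dt, num_steps):
  """Hypothetical standalone version of the time-to-step-index conversion."""
  if time == float('inf'):
    # An infinite time maps to the last available step.
    return num_steps - 1
  divided_time = time / dt
  quantized_time = int(np.round(divided_time))
  # Reject times that are not (approximately) integer multiples of dt.
  if np.abs(quantized_time - divided_time) > STEP_TIME_TOLERANCE:
    raise ValueError('`time` should be a multiple of dt = {}: got {}'
                     .format(dt, time))
  return quantized_time
```

With `dt = 0.02`, a time of `0.06` divides to a value just below 3 in floating point, and the tolerance lets it resolve to step 3, whereas `0.05` falls half a step off the grid and is rejected.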
diff --git a/dm_control/locomotion/mocap/walkers.py b/dm_control/locomotion/mocap/walkers.py
new file mode 100644
index 00000000..a5fdbab9
--- /dev/null
+++ b/dm_control/locomotion/mocap/walkers.py
@@ -0,0 +1,97 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Helpers for modifying a walker to match mocap data."""
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.locomotion.mocap import mocap_pb2
+from dm_control.locomotion.walkers import rescale
+import numpy as np
+
+
+class WalkerInfo:
+ """Encapsulates routines that modify a walker to match mocap data."""
+
+ def __init__(self, proto):
+ """Initializes this object.
+
+ Args:
+ proto: A `mocap_pb2.Walker` protocol buffer.
+ """
+ self._proto = proto
+
+ def check_walker_is_compatible(self, walker):
+ """Checks whether a given walker is compatible with this `WalkerInfo`."""
+ mocap_model = getattr(walker, 'mocap_walker_model', None)
+ if mocap_model is not None and mocap_model != self._proto.model:
+ model_type_name = list(mocap_pb2.Walker.Model.keys())[list(
+ mocap_pb2.Walker.Model.values()).index(self._proto.model)]
+ raise ValueError('Walker is not compatible with model type {!r}: got {}'
+ .format(model_type_name, walker))
+
+ def rescale_walker(self, walker):
+ """Rescales a given walker to match the data in this `WalkerInfo`."""
+ self.check_walker_is_compatible(walker)
+ for subtree_info in self._proto.scaling.subtree:
+ body = walker.mjcf_model.find('body', subtree_info.body_name)
+ subtree_root = body.parent
+ if subtree_info.parent_length:
+ position_factor = subtree_info.parent_length / np.linalg.norm(body.pos)
+ else:
+ position_factor = subtree_info.size_factor
+ rescale.rescale_subtree(
+ subtree_root, position_factor, subtree_info.size_factor)
+
+ if self._proto.mass:
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model.root_model)
+ current_mass = physics.bind(walker.root_body).subtreemass
+ mass_factor = self._proto.mass / current_mass
+ for body in walker.root_body.find_all('body'):
+ inertial = getattr(body, 'inertial', None)
+ if inertial:
+ inertial.mass *= mass_factor
+ for geom in walker.root_body.find_all('geom'):
+ if geom.mass is not None:
+ geom.mass *= mass_factor
+ else:
+ current_density = geom.density if geom.density is not None else 1000
+ geom.density = current_density * mass_factor
+
+ def add_marker_sites(self, walker, size=0.01, rgba=(0., 0., 1., .3),
+ default_to_random_position=True, random_state=None):
+ """Adds sites corresponding to mocap tracking markers."""
+ self.check_walker_is_compatible(walker)
+ random_state = random_state or np.random
+ sites = []
+ if self._proto.markers:
+ mocap_class = walker.mjcf_model.default.add('default', dclass='mocap')
+ mocap_class.site.set_attributes(type='sphere', size=(size,), rgba=rgba,
+ group=composer.SENSOR_SITES_GROUP)
+ for marker_info in self._proto.markers.marker:
+ body = walker.mjcf_model.find('body', marker_info.parent)
+ if not body:
+ raise ValueError('Walker model does not contain a body named {!r}'
+ .format(str(marker_info.parent)))
+ pos = marker_info.position
+ if not pos:
+ if default_to_random_position:
+ pos = random_state.uniform(-0.005, 0.005, size=3)
+ else:
+ pos = np.zeros(3)
+ sites.append(
+ body.add(
+ 'site', name=str(marker_info.name), pos=pos, dclass=mocap_class))
+ walker.list_of_site_names = [site.name for site in sites]
+ return sites
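
The mass rescaling branch of `WalkerInfo.rescale_walker` above can be illustrated without MuJoCo. This is a rough sketch under stated assumptions, not the patch's code: `Geom` is a hypothetical stand-in for an MJCF geom element and `scale_geom_masses` is a hypothetical helper. Only the rule itself is taken from the patch: scale an explicit mass directly, otherwise scale the density, falling back to 1000 when no density is set.

```python
import dataclasses
from typing import Optional

DEFAULT_DENSITY = 1000.0  # Fallback used by the patch when density is unset.


@dataclasses.dataclass
class Geom:
  """Hypothetical stand-in for an MJCF geom element."""
  mass: Optional[float] = None
  density: Optional[float] = None


def scale_geom_masses(geoms, current_mass, target_mass):
  """Scales geoms in place so the total mass moves toward target_mass."""
  mass_factor = target_mass / current_mass
  for geom in geoms:
    if geom.mass is not None:
      # An explicit mass is scaled directly.
      geom.mass *= mass_factor
    else:
      # Otherwise the mass is implied by density, so scale the density.
      current_density = (
          geom.density if geom.density is not None else DEFAULT_DENSITY)
      geom.density = current_density * mass_factor
  return mass_factor
```

For example, doubling the target mass doubles an explicit geom mass of 2.0 to 4.0 and turns an unset density into 2000.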
diff --git a/dm_control/locomotion/props/__init__.py b/dm_control/locomotion/props/__init__.py
new file mode 100644
index 00000000..84e687e0
--- /dev/null
+++ b/dm_control/locomotion/props/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Props for Locomotion tasks."""
+
+from dm_control.locomotion.props.target_sphere import TargetSphere
+from dm_control.locomotion.props.target_sphere import TargetSphereTwoTouch
diff --git a/dm_control/locomotion/props/target_sphere.py b/dm_control/locomotion/props/target_sphere.py
new file mode 100644
index 00000000..85ace41a
--- /dev/null
+++ b/dm_control/locomotion/props/target_sphere.py
@@ -0,0 +1,224 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A non-colliding sphere that is activated through touch."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+
+
+class TargetSphere(composer.Entity):
+ """A non-colliding sphere that is activated through touch.
+
+ Once the target has been reached, it remains in the "activated" state
+ for the remainder of the current episode.
+
+ The target is automatically reset to "not activated" state at episode
+ initialization time.
+ """
+
+ def _build(self,
+ radius=0.6,
+ height_above_ground=1,
+ rgb1=(0, 0.4, 0),
+ rgb2=(0, 0.7, 0),
+ specific_collision_geom_ids=None,
+ name='target'):
+ """Builds this target sphere.
+
+ Args:
+ radius: The radius (in meters) of this target sphere.
+ height_above_ground: The height (in meters) of this target above ground.
+ rgb1: A sequence of three floating point values between 0.0 and 1.0
+ (inclusive) representing the color of the first element in the stripe
+ pattern of the target.
+ rgb2: A sequence of three floating point values between 0.0 and 1.0
+ (inclusive) representing the color of the second element in the stripe
+ pattern of the target.
+ specific_collision_geom_ids: Only activate if collides with these geoms.
+ name: The name of this entity.
+ """
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere', type='cube',
+ builtin='checker', rgb1=rgb1, rgb2=rgb2,
+ width='100', height='100')
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='target_sphere', texture=self._texture)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', type='sphere', name='geom', gap=2*radius,
+ pos=[0, 0, height_above_ground], size=[radius], material=self._material)
+ self._geom_id = -1
+ self._activated = False
+ self._specific_collision_geom_ids = specific_collision_geom_ids
+
+ @property
+ def geom(self):
+ return self._geom
+
+ @property
+ def material(self):
+ return self._material
+
+ @property
+ def activated(self):
+ """Whether this target has been reached during this episode."""
+ return self._activated
+
+ def reset(self, physics):
+ self._activated = False
+ physics.bind(self._material).rgba[-1] = 1
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._activated = False
+
+ def _update_activation(self, physics):
+ if not self._activated:
+ for contact in physics.data.contact:
+ if self._specific_collision_geom_ids:
+ has_specific_collision = (
+ contact.geom1 in self._specific_collision_geom_ids or
+ contact.geom2 in self._specific_collision_geom_ids)
+ else:
+ has_specific_collision = True
+ if (has_specific_collision and
+ self._geom_id in (contact.geom1, contact.geom2)):
+ self._activated = True
+ physics.bind(self._material).rgba[-1] = 0
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._update_activation(physics)
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_activation(physics)
+
+
+class TargetSphereTwoTouch(composer.Entity):
+ """A non-colliding sphere that is activated through touch.
+
+ The target indicates if it has been touched at least once and touched at least
+ twice this episode with a two-bit activated state tuple. It remains activated
+ for the remainder of the current episode.
+
+ The target is automatically reset at episode initialization.
+ """
+
+ def _build(self,
+ radius=0.6,
+ height_above_ground=1,
+ rgb_initial=((0, 0.4, 0), (0, 0.7, 0)),
+ rgb_interval=((1., 1., .4), (0.7, 0.7, 0.)),
+ rgb_final=((.4, 0.7, 1.), (0, 0.4, .7)),
+ touch_debounce=.2,
+ specific_collision_geom_ids=None,
+ name='target'):
+ """Builds this target sphere.
+
+ Args:
+ radius: The radius (in meters) of this target sphere.
+ height_above_ground: The height (in meters) of this target above ground.
+ rgb_initial: Two stripe-pattern colors shown before the first touch.
+ rgb_interval: Two stripe-pattern colors shown after the first touch.
+ rgb_final: Two stripe-pattern colors shown after the second touch.
+ touch_debounce: Seconds after the first touch before a second counts.
+ specific_collision_geom_ids: Only activate if collides with these geoms.
+ name: The name of this entity.
+ """
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._texture_initial = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_init', type='cube',
+ builtin='checker', rgb1=rgb_initial[0], rgb2=rgb_initial[1],
+ width='100', height='100')
+ self._texture_interval = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_inter', type='cube',
+ builtin='checker', rgb1=rgb_interval[0], rgb2=rgb_interval[1],
+ width='100', height='100')
+ self._texture_final = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_final', type='cube',
+ builtin='checker', rgb1=rgb_final[0], rgb2=rgb_final[1],
+ width='100', height='100')
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='target_sphere_init', texture=self._texture_initial)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', type='sphere', name='geom', gap=2*radius,
+ pos=[0, 0, height_above_ground], size=[radius],
+ material=self._material)
+ self._geom_id = -1
+ self._touched_once = False
+ self._touched_twice = False
+ self._touch_debounce = touch_debounce
+ self._specific_collision_geom_ids = specific_collision_geom_ids
+
+ @property
+ def geom(self):
+ return self._geom
+
+ @property
+ def material(self):
+ return self._material
+
+ @property
+ def activated(self):
+ """Whether this target has been reached during this episode."""
+ return (self._touched_once, self._touched_twice)
+
+ def reset(self, physics):
+ self._touched_once = False
+ self._touched_twice = False
+ self._geom.material = self._material
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_initial).element_id
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._touched_once = False
+ self._touched_twice = False
+
+ def _update_activation(self, physics):
+ if not (self._touched_once and self._touched_twice):
+ for contact in physics.data.contact:
+ if self._specific_collision_geom_ids:
+ has_specific_collision = (
+ contact.geom1 in self._specific_collision_geom_ids or
+ contact.geom2 in self._specific_collision_geom_ids)
+ else:
+ has_specific_collision = True
+ if (has_specific_collision and
+ self._geom_id in (contact.geom1, contact.geom2)):
+ if not self._touched_once:
+ self._touched_once = True
+ self._touch_time = physics.time()
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_interval).element_id
+ if self._touched_once and (
+ physics.time() > (self._touch_time + self._touch_debounce)):
+ self._touched_twice = True
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_final).element_id
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._update_activation(physics)
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_activation(physics)
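Note on `TargetSphereTwoTouch`: the activation logic in this file amounts to a small state machine driven by contact events and simulation time. The sketch below restates it in plain Python, without MuJoCo. The class name `TwoTouchState` and the `update(in_contact, now)` signature are hypothetical and exist only for illustration; it assumes the second-touch check runs only while a qualifying contact is present.

```python
class TwoTouchState:
    """Hypothetical stand-in for the two-touch bookkeeping (not dm_control API)."""

    def __init__(self, touch_debounce=0.2):
        self.touch_debounce = touch_debounce
        self.reset()

    def reset(self):
        # Mirrors the per-episode reset: no touches recorded yet.
        self.touched_once = False
        self.touched_twice = False
        self.touch_time = None

    @property
    def activated(self):
        # Two-bit state: (touched at least once, touched at least twice).
        return (self.touched_once, self.touched_twice)

    def update(self, in_contact, now):
        """Advances the state given whether a qualifying contact exists at `now`."""
        if not in_contact or self.touched_twice:
            return self.activated
        if not self.touched_once:
            # First touch: remember when it happened.
            self.touched_once = True
            self.touch_time = now
        if now > self.touch_time + self.touch_debounce:
            # A contact seen after the debounce window counts as the second touch.
            self.touched_twice = True
        return self.activated
```

A contact that persists through the debounce window therefore registers as a second touch under this reading; the debounce only suppresses contacts that arrive too soon after the first one.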
diff --git a/dm_control/locomotion/props/target_sphere_test.py b/dm_control/locomotion/props/target_sphere_test.py
new file mode 100644
index 00000000..ec67f23d
--- /dev/null
+++ b/dm_control/locomotion/props/target_sphere_test.py
@@ -0,0 +1,64 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for props.target_sphere."""
+
+from absl.testing import absltest
+from dm_control import composer
+from dm_control.entities.props import primitive
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.props import target_sphere
+
+
+class TargetSphereTest(absltest.TestCase):
+
+ def testActivation(self):
+ target_radius = 0.6
+ prop_radius = 0.1
+ target_height = 1
+
+ arena = floors.Floor()
+ target = target_sphere.TargetSphere(radius=target_radius,
+ height_above_ground=target_height)
+ prop = primitive.Primitive(geom_type='sphere', size=[prop_radius])
+ arena.attach(target)
+ arena.add_free_entity(prop)
+
+ task = composer.NullTask(arena)
+ task.initialize_episode = (
+ lambda physics, random_state: prop.set_pose(physics, [0, 0, 2]))
+
+ env = composer.Environment(task)
+ env.reset()
+
+ max_activated_height = target_height + target_radius + prop_radius
+
+ while env.physics.bind(prop.geom).xpos[2] > max_activated_height:
+ self.assertFalse(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 1)
+ env.step([])
+
+ while env.physics.bind(prop.geom).xpos[2] > 0.2:
+ self.assertTrue(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 0)
+ env.step([])
+
+ # Target should be reset when the environment is reset.
+ env.reset()
+ self.assertFalse(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 1)
+
+
+if __name__ == '__main__':
+ absltest.main()
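Note on `testActivation`: the threshold `target_height + target_radius + prop_radius` is the prop-center height at which the two spheres first overlap when the prop falls straight down onto the target. A minimal sketch of that geometric check, using a hypothetical helper `spheres_overlap` that is not part of the patch:

```python
def spheres_overlap(center_a, radius_a, center_b, radius_b):
    """Returns True if two spheres touch or interpenetrate."""
    dist_sq = sum((a - b) ** 2 for a, b in zip(center_a, center_b))
    return dist_sq <= (radius_a + radius_b) ** 2


# Values taken from the test: target at height 1 with radius 0.6, prop radius 0.1.
target_center, target_radius = (0.0, 0.0, 1.0), 0.6
prop_radius = 0.1

# Prop released at z = 2: centers are 1.0 apart, more than 0.7, so no overlap.
before = spheres_overlap(target_center, target_radius, (0.0, 0.0, 2.0), prop_radius)
# Prop at z = 1.5: centers are 0.5 apart, less than 0.7, so the spheres overlap.
after = spheres_overlap(target_center, target_radius, (0.0, 0.0, 1.5), prop_radius)
```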
diff --git a/dm_control/locomotion/soccer/README.md b/dm_control/locomotion/soccer/README.md
new file mode 100644
index 00000000..e4e22ab4
--- /dev/null
+++ b/dm_control/locomotion/soccer/README.md
@@ -0,0 +1,78 @@
+# DeepMind MuJoCo Multi-Agent Soccer Environment.
+
+This submodule contains the components and environment used in the following
+works.
+
+* [Emergent Coordination through Competition][boxhead]
+ ([dynamic team play](https://www.youtube.com/watch?v=8nU35D8vAlo),
+ [defensive team play](https://www.youtube.com/watch?v=-gFQqB8L_mI)).
+
+* [From Motor Control to Team Play in Simulated Humanoid Football][humanoid].
+
+
+
+## Quickstart
+
+```python
+import numpy as np
+from dm_control.locomotion import soccer as dm_soccer
+
+# Instantiates a 2-vs-2 BOXHEAD soccer environment with episodes of 10 seconds
+# each. Upon scoring, the environment resets player positions and the episode
+# continues. In this example, players can physically block each other and the
+# ball is trapped within an invisible box encapsulating the field.
+env = dm_soccer.load(team_size=2,
+ time_limit=10.0,
+ disable_walker_contacts=False,
+ enable_field_box=True,
+ terminate_on_goal=False,
+ walker_type=dm_soccer.WalkerType.BOXHEAD)
+
+# Retrieves action_specs for all 4 players.
+action_specs = env.action_spec()
+
+# Step through the environment for one episode with random actions.
+timestep = env.reset()
+while not timestep.last():
+ actions = []
+ for action_spec in action_specs:
+ action = np.random.uniform(
+ action_spec.minimum, action_spec.maximum, size=action_spec.shape)
+ actions.append(action)
+ timestep = env.step(actions)
+
+ for i in range(len(action_specs)):
+ print(
+ "Player {}: reward = {}, discount = {}, observations = {}.".format(
+ i, timestep.reward[i], timestep.discount, timestep.observation[i]))
+```
+
+## Rewards
+
+The environment provides a reward of +1 to each player when their team scores a
+goal, -1 when their team concedes a goal, or 0 if neither team scored on the
+current timestep.
+
+In addition to the sparse reward returned by the environment, the player
+observations also contain various environment statistics that may be used to
+derive custom per-player shaping rewards. See `environment.observation_spec()`
+for the additional statistics available to the agents.
+
+## Episode terminations
+
+If `terminate_on_goal` is set to `True`, episodes terminate immediately with a
+discount factor of `0` when either side scores a goal. If neither team scores
+before `time_limit` elapses, the episode terminates with a discount factor of
+`1.0`.
+
+If `terminate_on_goal` is set to `False`, player and ball positions are
+randomly re-initialized whenever either team scores a goal. Episodes always
+terminate after `time_limit` with a discount factor of `1.0`.
+
+## Environment Viewer
+
+To visualize an example environment instance using the `dm_control` interactive
+viewer, execute `dm_control/locomotion/soccer/explore.py`.
+
+[boxhead]: http://arxiv.org/abs/1902.07151
+[humanoid]: https://arxiv.org/abs/2105.12196
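Note on the README's termination semantics: the discount returned on the final timestep tells a learning agent whether to bootstrap from its value estimate. A minimal sketch of a one-step TD target per player, assuming the reward and discount conventions described in the README; `gamma`, `next_values`, and the function name `per_player_td_targets` are agent-side, hypothetical names and not part of the environment API.

```python
def per_player_td_targets(rewards, env_discount, next_values, gamma=0.99):
    """Computes reward + gamma * env_discount * next_value for each player.

    Args:
      rewards: Per-player rewards for the current step (+1, -1 or 0).
      env_discount: Discount reported by the environment for this step.
      next_values: Per-player value estimates for the next observation.
      gamma: Agent-side discount factor (an assumption of this sketch).
    """
    return [r + gamma * env_discount * v for r, v in zip(rewards, next_values)]


# Goal scored with terminate_on_goal=True: discount 0, so no bootstrapping.
goal_targets = per_player_td_targets([1.0, -1.0], 0.0, [5.0, 5.0], gamma=0.5)
# Time limit reached: discount 1.0, so the value estimate is still used.
timeout_targets = per_player_td_targets([0.0, 0.0], 1.0, [5.0, 5.0], gamma=0.5)
```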
diff --git a/dm_control/locomotion/soccer/__init__.py b/dm_control/locomotion/soccer/__init__.py
new file mode 100644
index 00000000..fc6f59af
--- /dev/null
+++ b/dm_control/locomotion/soccer/__init__.py
@@ -0,0 +1,152 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Multi-agent MuJoCo soccer environment."""
+
+import enum
+
+from dm_control import composer
+from dm_control.locomotion import walkers
+from dm_control.locomotion.soccer.boxhead import BoxHead
+from dm_control.locomotion.soccer.humanoid import Humanoid
+from dm_control.locomotion.soccer.initializers import Initializer
+from dm_control.locomotion.soccer.initializers import UniformInitializer
+from dm_control.locomotion.soccer.observables import CoreObservablesAdder
+from dm_control.locomotion.soccer.observables import InterceptionObservablesAdder
+from dm_control.locomotion.soccer.observables import MultiObservablesAdder
+from dm_control.locomotion.soccer.observables import ObservablesAdder
+from dm_control.locomotion.soccer.pitch import MINI_FOOTBALL_GOAL_SIZE
+from dm_control.locomotion.soccer.pitch import MINI_FOOTBALL_MAX_AREA_PER_HUMANOID
+from dm_control.locomotion.soccer.pitch import MINI_FOOTBALL_MIN_AREA_PER_HUMANOID
+from dm_control.locomotion.soccer.pitch import Pitch
+from dm_control.locomotion.soccer.pitch import RandomizedPitch
+from dm_control.locomotion.soccer.soccer_ball import regulation_soccer_ball
+from dm_control.locomotion.soccer.soccer_ball import SoccerBall
+from dm_control.locomotion.soccer.task import MultiturnTask
+from dm_control.locomotion.soccer.task import Task
+from dm_control.locomotion.soccer.team import Player
+from dm_control.locomotion.soccer.team import RGBA_BLUE
+from dm_control.locomotion.soccer.team import RGBA_RED
+from dm_control.locomotion.soccer.team import Team
+from dm_control.locomotion.walkers.initializers import mocap
+import numpy as np
+
+
+class WalkerType(enum.Enum):
+ BOXHEAD = 0
+ ANT = 1
+ HUMANOID = 2
+
+
+def _make_walker(name, walker_id, marker_rgba, walker_type=WalkerType.BOXHEAD):
+ """Construct a walker of the given `walker_type`."""
+ if walker_type == WalkerType.BOXHEAD:
+ return BoxHead(
+ name=name,
+ walker_id=walker_id,
+ marker_rgba=marker_rgba,
+ )
+ if walker_type == WalkerType.ANT:
+ return walkers.Ant(name=name, marker_rgba=marker_rgba)
+ if walker_type == WalkerType.HUMANOID:
+ initializer = mocap.CMUMocapInitializer()
+ return Humanoid(
+ name=name,
+ marker_rgba=marker_rgba,
+ walker_id=walker_id,
+ visual=Humanoid.Visual.JERSEY,
+ initializer=initializer)
+ raise ValueError("Unrecognized walker type: %s" % walker_type)
+
+
+def _make_players(team_size, walker_type):
+ """Construct home and away teams each of `team_size` players."""
+ home_players = []
+ away_players = []
+ for i in range(team_size):
+ home_walker = _make_walker("home%d" % i, i, RGBA_BLUE, walker_type)
+ home_players.append(Player(Team.HOME, home_walker))
+
+ away_walker = _make_walker("away%d" % i, i, RGBA_RED, walker_type)
+ away_players.append(Player(Team.AWAY, away_walker))
+ return home_players + away_players
+
+
+def _area_to_size(area, aspect_ratio=0.75):
+ """Convert from area and aspect_ratio to (width, height)."""
+ return np.sqrt([area / aspect_ratio, area * aspect_ratio]) / 2.
+
+
+def load(team_size,
+ time_limit=45.,
+ random_state=None,
+ disable_walker_contacts=False,
+ enable_field_box=False,
+ keep_aspect_ratio=False,
+ terminate_on_goal=True,
+ walker_type=WalkerType.BOXHEAD):
+ """Construct `team_size`-vs-`team_size` soccer environment.
+
+ Args:
+ team_size: Integer, the number of players per team. Must be between 1 and
+ 11.
+ time_limit: Float, the maximum duration of each episode in seconds.
+ random_state: (optional) an int seed or `np.random.RandomState` instance.
+ disable_walker_contacts: (optional) if `True`, disable physical contacts
+ between walkers.
+ enable_field_box: (optional) if `True`, enable physical bounding box for
+ the soccer ball (but not the players).
+ keep_aspect_ratio: (optional) if `True`, maintain constant pitch aspect
+ ratio.
+ terminate_on_goal: (optional) if `False`, game play continues across
+ scoring events instead of ending the episode at the first goal.
+ walker_type: the type of walker to instantiate in the environment.
+
+ Returns:
+ A `composer.Environment` instance.
+
+ Raises:
+ ValueError: If `team_size` is not between 1 and 11.
+ ValueError: If `walker_type` is not recognized.
+ """
+ goal_size = None
+ min_size = (32, 24)
+ max_size = (48, 36)
+ ball = SoccerBall()
+
+ if walker_type == WalkerType.HUMANOID:
+ goal_size = MINI_FOOTBALL_GOAL_SIZE
+ num_walkers = team_size * 2
+ min_size = _area_to_size(MINI_FOOTBALL_MIN_AREA_PER_HUMANOID * num_walkers)
+ max_size = _area_to_size(MINI_FOOTBALL_MAX_AREA_PER_HUMANOID * num_walkers)
+ ball = regulation_soccer_ball()
+
+ task_factory = Task
+ if not terminate_on_goal:
+ task_factory = MultiturnTask
+
+ return composer.Environment(
+ task=task_factory(
+ players=_make_players(team_size, walker_type),
+ arena=RandomizedPitch(
+ min_size=min_size,
+ max_size=max_size,
+ keep_aspect_ratio=keep_aspect_ratio,
+ field_box=enable_field_box,
+ goal_size=goal_size),
+ ball=ball,
+ disable_walker_contacts=disable_walker_contacts),
+ time_limit=time_limit,
+ random_state=random_state)
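Note on `_area_to_size`: the returned pair is `sqrt(area / aspect_ratio) / 2` and `sqrt(area * aspect_ratio) / 2`. Doubling both values gives side lengths whose product is `area` and whose ratio is `aspect_ratio`, so the results read as half-extents. The sketch below reproduces the arithmetic with the standard library only; the name `area_to_half_extents` and the example area of 1200 are illustrative assumptions, not values taken from the patch.

```python
import math


def area_to_half_extents(area, aspect_ratio=0.75):
    """Returns half of each side of a rectangle with the given area and ratio."""
    long_side = math.sqrt(area / aspect_ratio)
    short_side = math.sqrt(area * aspect_ratio)
    return (long_side / 2.0, short_side / 2.0)


half_x, half_y = area_to_half_extents(1200.0, aspect_ratio=0.75)
# Full sides are 2 * half_x and 2 * half_y; their product recovers the area.
recovered_area = (2.0 * half_x) * (2.0 * half_y)
recovered_ratio = half_y / half_x
```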
diff --git a/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml b/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml
new file mode 100644
index 00000000..50961555
--- /dev/null
+++ b/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml
@@ -0,0 +1,49 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/00.png b/dm_control/locomotion/soccer/assets/boxhead/digits/00.png
new file mode 100644
index 00000000..b5fde93c
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/00.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/01.png b/dm_control/locomotion/soccer/assets/boxhead/digits/01.png
new file mode 100644
index 00000000..f52bae48
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/01.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/02.png b/dm_control/locomotion/soccer/assets/boxhead/digits/02.png
new file mode 100644
index 00000000..bec2be6c
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/02.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/03.png b/dm_control/locomotion/soccer/assets/boxhead/digits/03.png
new file mode 100644
index 00000000..654dba27
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/03.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/04.png b/dm_control/locomotion/soccer/assets/boxhead/digits/04.png
new file mode 100644
index 00000000..f320726f
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/04.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/05.png b/dm_control/locomotion/soccer/assets/boxhead/digits/05.png
new file mode 100644
index 00000000..64d8cd6c
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/05.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/06.png b/dm_control/locomotion/soccer/assets/boxhead/digits/06.png
new file mode 100644
index 00000000..ed97f8bd
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/06.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/07.png b/dm_control/locomotion/soccer/assets/boxhead/digits/07.png
new file mode 100644
index 00000000..4808c29a
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/07.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/08.png b/dm_control/locomotion/soccer/assets/boxhead/digits/08.png
new file mode 100644
index 00000000..41d09b9b
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/08.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/09.png b/dm_control/locomotion/soccer/assets/boxhead/digits/09.png
new file mode 100644
index 00000000..649723c7
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/09.png differ
diff --git a/dm_control/locomotion/soccer/assets/boxhead/digits/10.png b/dm_control/locomotion/soccer/assets/boxhead/digits/10.png
new file mode 100644
index 00000000..d2cc0d08
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/boxhead/digits/10.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_01.png b/dm_control/locomotion/soccer/assets/humanoid/B_01.png
new file mode 100644
index 00000000..9663c773
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_01.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_02.png b/dm_control/locomotion/soccer/assets/humanoid/B_02.png
new file mode 100644
index 00000000..00de3d7e
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_02.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_03.png b/dm_control/locomotion/soccer/assets/humanoid/B_03.png
new file mode 100644
index 00000000..790928ea
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_03.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_04.png b/dm_control/locomotion/soccer/assets/humanoid/B_04.png
new file mode 100644
index 00000000..a156caf6
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_04.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_05.png b/dm_control/locomotion/soccer/assets/humanoid/B_05.png
new file mode 100644
index 00000000..b4bbafa1
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_05.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_06.png b/dm_control/locomotion/soccer/assets/humanoid/B_06.png
new file mode 100644
index 00000000..3f927e35
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_06.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_07.png b/dm_control/locomotion/soccer/assets/humanoid/B_07.png
new file mode 100644
index 00000000..9fec0147
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_07.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_08.png b/dm_control/locomotion/soccer/assets/humanoid/B_08.png
new file mode 100644
index 00000000..96415721
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_08.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_09.png b/dm_control/locomotion/soccer/assets/humanoid/B_09.png
new file mode 100644
index 00000000..d2beedaa
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_09.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_10.png b/dm_control/locomotion/soccer/assets/humanoid/B_10.png
new file mode 100644
index 00000000..1886517e
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_10.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/B_11.png b/dm_control/locomotion/soccer/assets/humanoid/B_11.png
new file mode 100644
index 00000000..7547b89a
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/B_11.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_01.png b/dm_control/locomotion/soccer/assets/humanoid/R_01.png
new file mode 100644
index 00000000..075c5262
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_01.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_02.png b/dm_control/locomotion/soccer/assets/humanoid/R_02.png
new file mode 100644
index 00000000..0cff6c2e
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_02.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_03.png b/dm_control/locomotion/soccer/assets/humanoid/R_03.png
new file mode 100644
index 00000000..693e1c92
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_03.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_04.png b/dm_control/locomotion/soccer/assets/humanoid/R_04.png
new file mode 100644
index 00000000..22597737
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_04.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_05.png b/dm_control/locomotion/soccer/assets/humanoid/R_05.png
new file mode 100644
index 00000000..b785f5d0
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_05.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_06.png b/dm_control/locomotion/soccer/assets/humanoid/R_06.png
new file mode 100644
index 00000000..67bfa3b3
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_06.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_07.png b/dm_control/locomotion/soccer/assets/humanoid/R_07.png
new file mode 100644
index 00000000..2dcf4f31
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_07.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_08.png b/dm_control/locomotion/soccer/assets/humanoid/R_08.png
new file mode 100644
index 00000000..3b8841d3
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_08.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_09.png b/dm_control/locomotion/soccer/assets/humanoid/R_09.png
new file mode 100644
index 00000000..e7b6f488
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_09.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_10.png b/dm_control/locomotion/soccer/assets/humanoid/R_10.png
new file mode 100644
index 00000000..1b386892
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_10.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/R_11.png b/dm_control/locomotion/soccer/assets/humanoid/R_11.png
new file mode 100644
index 00000000..f130a5eb
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/R_11.png differ
diff --git a/dm_control/locomotion/soccer/assets/humanoid/jersey.skn b/dm_control/locomotion/soccer/assets/humanoid/jersey.skn
new file mode 100644
index 00000000..fca6d8e4
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/humanoid/jersey.skn differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_l.png b/dm_control/locomotion/soccer/assets/pitch/pitch_l.png
new file mode 100644
index 00000000..4bda0d07
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_l.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_m.png b/dm_control/locomotion/soccer/assets/pitch/pitch_m.png
new file mode 100644
index 00000000..e0f103b5
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_m.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_l.png b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_l.png
new file mode 100644
index 00000000..0523153b
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_l.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_m.png b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_m.png
new file mode 100644
index 00000000..f83783bf
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_m.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_s.png b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_s.png
new file mode 100644
index 00000000..7f55902a
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_s.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_xs.png b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_xs.png
new file mode 100644
index 00000000..bd553da0
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_nologo_xs.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_s.png b/dm_control/locomotion/soccer/assets/pitch/pitch_s.png
new file mode 100644
index 00000000..d6bb05eb
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_s.png differ
diff --git a/dm_control/locomotion/soccer/assets/pitch/pitch_xs.png b/dm_control/locomotion/soccer/assets/pitch/pitch_xs.png
new file mode 100644
index 00000000..b15b5c2e
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/pitch/pitch_xs.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/back.png b/dm_control/locomotion/soccer/assets/soccer_ball/back.png
new file mode 100644
index 00000000..c01b5171
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/back.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/down.png b/dm_control/locomotion/soccer/assets/soccer_ball/down.png
new file mode 100644
index 00000000..49ace5b6
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/down.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/front.png b/dm_control/locomotion/soccer/assets/soccer_ball/front.png
new file mode 100644
index 00000000..c18fbf04
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/front.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/left.png b/dm_control/locomotion/soccer/assets/soccer_ball/left.png
new file mode 100644
index 00000000..d3b14f7b
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/left.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/right.png b/dm_control/locomotion/soccer/assets/soccer_ball/right.png
new file mode 100644
index 00000000..69965a81
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/right.png differ
diff --git a/dm_control/locomotion/soccer/assets/soccer_ball/up.png b/dm_control/locomotion/soccer/assets/soccer_ball/up.png
new file mode 100644
index 00000000..8c729d70
Binary files /dev/null and b/dm_control/locomotion/soccer/assets/soccer_ball/up.png differ
diff --git a/dm_control/locomotion/soccer/boxhead.py b/dm_control/locomotion/soccer/boxhead.py
new file mode 100644
index 00000000..4f42aba4
--- /dev/null
+++ b/dm_control/locomotion/soccer/boxhead.py
@@ -0,0 +1,350 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Walkers based on an actuated jumping ball."""
+
+import io
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import legacy_base
+import numpy as np
+from PIL import Image
+
+from dm_control.utils import io as resources
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'boxhead')
+_MAX_WALKER_ID = 10
+_INVALID_WALKER_ID = 'walker_id must be in [0-{}], got: {{}}.'.format(
+ _MAX_WALKER_ID)
+
+
+def _compensate_gravity(physics, body_elements):
+ """Applies Cartesian forces to bodies in order to exactly counteract gravity.
+
+ Note that this will also affect the output of pressure, force, or torque
+ sensors within the kinematic chain leading from the worldbody to the bodies
+ that are being gravity-compensated.
+
+ Args:
+ physics: An `mjcf.Physics` instance to modify.
+ body_elements: An iterable of `mjcf.Element`s specifying the bodies to which
+ gravity compensation will be applied.
+ """
+ gravity = np.hstack([physics.model.opt.gravity, [0, 0, 0]])
+ bodies = physics.bind(body_elements)
+ bodies.xfrc_applied = -gravity * bodies.mass[..., None]
+
+
+def _alpha_blend(foreground, background):
+ """Does alpha compositing of two RGBA images.
+
+ Both inputs must be (..., 4) numpy arrays whose shapes are compatible for
+ broadcasting. They are assumed to contain float RGBA values in [0, 1].
+
+ Args:
+ foreground: foreground RGBA image.
+ background: background RGBA image.
+
+ Returns:
+ A numpy array of shape (..., 4) containing the blended image.
+ """
+ fg, bg = np.broadcast_arrays(foreground, background)
+ fg_rgb = fg[..., :3]
+ fg_a = fg[..., 3:]
+ bg_rgb = bg[..., :3]
+ bg_a = bg[..., 3:]
+ out = np.empty_like(bg)
+ out_a = out[..., 3:]
+ out_rgb = out[..., :3]
+ # https://en.wikipedia.org/wiki/Alpha_compositing#Alpha_blending
+ out_a[:] = fg_a + bg_a * (1. - fg_a)
+ out_rgb[:] = fg_rgb * fg_a + bg_rgb * bg_a * (1. - fg_a)
+ # Avoid division by zero if foreground and background are both transparent.
+ out_rgb[:] = np.where(out_a, out_rgb / out_a, out_rgb)
+ return out
+
+
+def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba):
+ """Decode PNG from asset file and add solid background."""
+
+ # Retrieve PNG image contents as a bytestring, convert to a numpy array.
+ contents = resources.GetResource(os.path.join(_ASSETS_PATH, asset_fname))
+ digit_rgba = np.array(Image.open(io.BytesIO(contents)), dtype=np.double)
+
+ # Add solid background with `background_rgba`.
+ blended = 255. * _alpha_blend(digit_rgba / 255., np.asarray(background_rgba))
+
+ # Encode composite image array to a PNG bytestring.
+ img = Image.fromarray(blended.astype(np.uint8), mode='RGBA')
+ buf = io.BytesIO()
+ img.save(buf, format='PNG')
+ png_encoding = buf.getvalue()
+ buf.close()
+
+ return png_encoding
+
+
+class BoxHeadObservables(legacy_base.WalkerObservables):
+ """BoxHead observables with low-res camera and modulo'd rotational joints."""
+
+ def __init__(self, entity, camera_resolution):
+ self._camera_resolution = camera_resolution
+ super().__init__(entity)
+
+ @composer.observable
+ def egocentric_camera(self):
+ width, height = self._camera_resolution
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=width, height=height)
+
+ @property
+ def proprioception(self):
+ proprioception = super().proprioception
+ if self._entity.observable_camera_joints:
+ return proprioception + [self.camera_joints_pos, self.camera_joints_vel]
+ return proprioception
+
+ @composer.observable
+ def camera_joints_pos(self):
+
+ def _sin(value, random_state):
+ del random_state
+ return np.sin(value)
+
+ def _cos(value, random_state):
+ del random_state
+ return np.cos(value)
+
+ sin_rotation_joints = observable.MJCFFeature(
+ 'qpos', self._entity.observable_camera_joints, corruptor=_sin)
+
+ cos_rotation_joints = observable.MJCFFeature(
+ 'qpos', self._entity.observable_camera_joints, corruptor=_cos)
+
+ def _camera_joints(physics):
+ return np.concatenate([
+ sin_rotation_joints(physics),
+ cos_rotation_joints(physics)
+ ], -1)
+
+ return observable.Generic(_camera_joints)
+
+ @composer.observable
+ def camera_joints_vel(self):
+ return observable.MJCFFeature(
+ 'qvel', self._entity.observable_camera_joints)
+
+
+class BoxHead(legacy_base.Walker):
+ """A rollable and jumpable ball with a head."""
+
+ def _build(self,
+ name='walker',
+ marker_rgba=None,
+ camera_control=False,
+ camera_resolution=(28, 28),
+ roll_gear=-60,
+ steer_gear=55,
+ walker_id=None,
+ initializer=None):
+ """Build a BoxHead.
+
+ Args:
+ name: name of the walker.
+ marker_rgba: RGBA value set to walker.marker_geoms to distinguish between
+ walkers (in multi-agent setting).
+ camera_control: If `True`, the walker exposes two additional actuated
+ degrees of freedom to control the egocentric camera height and tilt.
+ camera_resolution: egocentric camera rendering resolution.
+ roll_gear: gear determining forward acceleration.
+ steer_gear: gear determining steering (spinning) torque.
+ walker_id: (Optional) An integer in [0, 10]; this number is shown on
+ the walker's head. Defaults to `None`, which does not show any number.
+ initializer: (Optional) A `WalkerInitializer` object.
+
+ Raises:
+ ValueError: if received invalid walker_id.
+ """
+ super()._build(initializer=initializer)
+ xml_path = os.path.join(_ASSETS_PATH, 'boxhead.xml')
+ self._mjcf_root = mjcf.from_xml_string(resources.GetResource(xml_path, 'r'))
+ if name:
+ self._mjcf_root.model = name
+
+ if walker_id is not None and not 0 <= walker_id <= _MAX_WALKER_ID:
+ raise ValueError(_INVALID_WALKER_ID.format(walker_id))
+
+ self._walker_id = walker_id
+ if walker_id is not None:
+ png_bytes = _asset_png_with_background_rgba_bytes(
+ 'digits/%02d.png' % walker_id, marker_rgba)
+ head_texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='head_texture',
+ type='2d',
+ file=mjcf.Asset(png_bytes, '.png'))
+ head_material = self._mjcf_root.asset.add(
+ 'material', name='head_material', texture=head_texture)
+ self._mjcf_root.find('geom', 'head').material = head_material
+ self._mjcf_root.find('geom', 'head').rgba = None
+
+ self._mjcf_root.find('geom', 'top_down_cam_box').material = head_material
+ self._mjcf_root.find('geom', 'top_down_cam_box').rgba = None
+
+ self._body_texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='ball_body',
+ type='cube',
+ builtin='checker',
+ rgb1=marker_rgba[:-1] if marker_rgba is not None else '.4 .4 .4',
+ rgb2='.8 .8 .8',
+ width='100',
+ height='100')
+ self._body_material = self._mjcf_root.asset.add(
+ 'material', name='ball_body', texture=self._body_texture)
+ self._mjcf_root.find('geom', 'shell').material = self._body_material
+
+ # Set corresponding marker color if specified.
+ if marker_rgba is not None:
+ for geom in self.marker_geoms:
+ geom.set_attributes(rgba=marker_rgba)
+
+ self._root_joints = None
+ self._camera_control = camera_control
+ self._camera_resolution = camera_resolution
+ if not camera_control:
+ for name in ('camera_pitch', 'camera_yaw'):
+ self._mjcf_root.find('actuator', name).remove()
+ self._mjcf_root.find('joint', name).remove()
+ self._roll_gear = roll_gear
+ self._steer_gear = steer_gear
+ self._mjcf_root.find('actuator', 'roll').gear[0] = self._roll_gear
+ self._mjcf_root.find('actuator', 'steer').gear[0] = self._steer_gear
+
+ # Initialize previous action.
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ def _build_observables(self):
+ return BoxHeadObservables(self, camera_resolution=self._camera_resolution)
+
+ @property
+ def marker_geoms(self):
+ geoms = [
+ self._mjcf_root.find('geom', 'arm_l'),
+ self._mjcf_root.find('geom', 'arm_r'),
+ self._mjcf_root.find('geom', 'eye_l'),
+ self._mjcf_root.find('geom', 'eye_r'),
+ ]
+ if self._walker_id is None:
+ geoms.append(self._mjcf_root.find('geom', 'head'))
+ return geoms
+
+ def create_root_joints(self, attachment_frame):
+ root_class = self._mjcf_root.find('default', 'root')
+ root_x = attachment_frame.add(
+ 'joint', name='root_x', type='slide', axis=[1, 0, 0], dclass=root_class)
+ root_y = attachment_frame.add(
+ 'joint', name='root_y', type='slide', axis=[0, 1, 0], dclass=root_class)
+ root_z = attachment_frame.add(
+ 'joint', name='root_z', type='slide', axis=[0, 0, 1], dclass=root_class)
+ self._root_joints = [root_x, root_y, root_z]
+
+ def set_pose(self, physics, position=None, quaternion=None):
+ if position is not None:
+ if self._root_joints is not None:
+ physics.bind(self._root_joints).qpos = position
+ else:
+ super().set_pose(physics, position, quaternion=None)
+ physics.bind(self._mjcf_root.find_all('joint')).qpos = 0.
+ if quaternion is not None:
+ # This walker can only rotate about the z-axis, so we extract only that
+ # component from the quaternion.
+ z_angle = np.arctan2(
+ 2 * (quaternion[0] * quaternion[3] + quaternion[1] * quaternion[2]),
+ 1 - 2 * (quaternion[2] ** 2 + quaternion[3] ** 2))
+ physics.bind(self._mjcf_root.find('joint', 'steer')).qpos = z_angle
+
+ def set_velocity(self, physics, velocity=None, angular_velocity=None):
+ if velocity is not None:
+ if self._root_joints is not None:
+ physics.bind(self._root_joints).qvel = velocity
+
+ if angular_velocity is not None:
+ # This walker can only rotate about the z-axis, so we extract only that
+ # component from the angular_velocity.
+ steer_joint = self._mjcf_root.find('joint', 'steer')
+ if np.isscalar(angular_velocity):
+ z_velocity = angular_velocity
+ else:
+ z_velocity = angular_velocity[2]
+ physics.bind(steer_joint).qvel = z_velocity
+
+ def initialize_episode(self, physics, random_state):
+ if self._camera_control:
+ _compensate_gravity(physics,
+ self._mjcf_root.find('body', 'egocentric_camera'))
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ def apply_action(self, physics, action, random_state):
+ super().apply_action(physics, action, random_state)
+
+ # Updates previous action.
+ self._prev_action[:] = action
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ return self._mjcf_root.find_all('actuator')
+
+ @composer.cached_property
+ def root_body(self):
+ return self._mjcf_root.find('body', 'head_body')
+
+ @composer.cached_property
+ def end_effectors(self):
+ return (self._mjcf_root.find('body', 'head_body'),)
+
+ @composer.cached_property
+ def observable_joints(self):
+ return (self._mjcf_root.find('joint', 'kick'),)
+
+ @composer.cached_property
+ def observable_camera_joints(self):
+ if self._camera_control:
+ return (
+ self._mjcf_root.find('joint', 'camera_yaw'),
+ self._mjcf_root.find('joint', 'camera_pitch'),
+ )
+ return ()
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ return (self._mjcf_root.find('geom', 'shell'),)
+
+ @property
+ def prev_action(self):
+ return self._prev_action
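Review note on `BoxHead.set_pose` in boxhead.py: the z-axis extraction can be checked in isolation. The sketch below is a standalone illustration, not part of the patch. It assumes quaternions are ordered (w, x, y, z), as the indexing in `set_pose` implies, and the helper names `yaw_from_quaternion` and `quaternion_from_yaw` are hypothetical. The second helper mirrors the quaternion that `UniformInitializer._initialize_walker` builds from a random rotation.

```python
import numpy as np


def yaw_from_quaternion(quaternion):
  """Returns the rotation about the z-axis of a (w, x, y, z) quaternion."""
  w, x, y, z = quaternion
  # Same expression as in `BoxHead.set_pose`.
  return np.arctan2(2 * (w * z + x * y), 1 - 2 * (y ** 2 + z ** 2))


def quaternion_from_yaw(angle):
  """Builds a (w, x, y, z) quaternion for a pure rotation about the z-axis."""
  return np.array([np.cos(angle / 2), 0., 0., np.sin(angle / 2)])


# Round trip: a pure z-rotation should be recovered for angles in (-pi, pi).
for angle in (-2.5, 0., 0.3, 3.0):
  assert np.isclose(yaw_from_quaternion(quaternion_from_yaw(angle)), angle)
```

For a pure z-rotation the two `arctan2` arguments reduce to sin(angle) and cos(angle), which is why the round trip holds across the open interval (-pi, pi).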
diff --git a/dm_control/locomotion/soccer/boxhead_test.py b/dm_control/locomotion/soccer/boxhead_test.py
new file mode 100644
index 00000000..53c06e7c
--- /dev/null
+++ b/dm_control/locomotion/soccer/boxhead_test.py
@@ -0,0 +1,43 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.boxhead."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion.soccer import boxhead
+
+
+class BoxheadTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ dict(camera_control=True, walker_id=None),
+ dict(camera_control=False, walker_id=None),
+ dict(camera_control=True, walker_id=0),
+ dict(camera_control=False, walker_id=10))
+ def test_instantiation(self, camera_control, walker_id):
+ boxhead.BoxHead(marker_rgba=[.8, .1, .1, 1.],
+ camera_control=camera_control,
+ walker_id=walker_id)
+
+ @parameterized.parameters(-1, 11)
+ def test_invalid_walker_id(self, walker_id):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, boxhead._INVALID_WALKER_ID.format(walker_id)):
+ boxhead.BoxHead(walker_id=walker_id)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/soccer/camera.py b/dm_control/locomotion/soccer/camera.py
new file mode 100644
index 00000000..6a3f1f85
--- /dev/null
+++ b/dm_control/locomotion/soccer/camera.py
@@ -0,0 +1,119 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Cameras for recording soccer videos."""
+
+from dm_control.mujoco import engine
+import numpy as np
+
+
+class MultiplayerTrackingCamera:
+ """Camera that smoothly tracks multiple entities."""
+
+ def __init__(
+ self,
+ min_distance,
+ distance_factor,
+ smoothing_update_speed,
+ azimuth=90,
+ elevation=-45,
+ width=1920,
+ height=1080,
+ ):
+ """Construct a new MultiplayerTrackingCamera.
+
+ The target lookat point is the centroid of all tracked entities.
+ Target camera distance is set to min_distance + distance_factor * d_max,
+ where d_max is the maximum distance of any entity to the lookat point.
+
+ Args:
+ min_distance: minimum camera distance.
+ distance_factor: camera distance multiplier (see above).
+ smoothing_update_speed: exponential filter parameter to smooth camera
+ movement. 1 means no filter; smaller values mean less change per step.
+ azimuth: constant azimuth to use for camera.
+ elevation: constant elevation to use for camera.
+ width: width to use for rendered video.
+ height: height to use for rendered video.
+ """
+ self._min_distance = min_distance
+ self._distance_factor = distance_factor
+ if smoothing_update_speed < 0 or smoothing_update_speed > 1:
+ raise ValueError("Filter speed must be in range [0, 1].")
+ self._smoothing_update_speed = smoothing_update_speed
+ self._azimuth = azimuth
+ self._elevation = elevation
+ self._width = width
+ self._height = height
+ self._camera = None
+
+ @property
+ def camera(self):
+ return self._camera
+
+ def render(self):
+ """Render the current frame."""
+ if self._camera is None:
+ raise ValueError(
+ "Camera has not been initialized yet."
+ " render can only be called after physics has been compiled."
+ )
+ return self._camera.render()
+
+ def after_compile(self, physics):
+ """Instantiate the camera and ensure rendering buffer is large enough."""
+ buffer_height = max(self._height, physics.model.vis.global_.offheight)
+ buffer_width = max(self._width, physics.model.vis.global_.offwidth)
+ physics.model.vis.global_.offheight = buffer_height
+ physics.model.vis.global_.offwidth = buffer_width
+ self._camera = engine.MovableCamera(
+ physics, height=self._height, width=self._width)
+
+ def _get_target_camera_pose(self, entity_positions):
+ """Returns the pose that the camera should be pulled toward.
+
+ Args:
+ entity_positions: list of numpy arrays representing current positions of
+ the entities to be tracked.
+ Returns: mujoco.engine.Pose representing the target camera pose.
+ """
+ stacked_positions = np.stack(entity_positions)
+ centroid = np.mean(stacked_positions, axis=0)
+ radii = np.linalg.norm(stacked_positions - centroid, axis=1)
+ assert len(radii) == len(entity_positions)
+ camera_distance = self._min_distance + self._distance_factor * np.max(radii)
+ return engine.Pose(
+ lookat=centroid,
+ distance=camera_distance,
+ azimuth=self._azimuth,
+ elevation=self._elevation,
+ )
+
+ def initialize_episode(self, entity_positions):
+ """Begin the episode with the camera set to its target pose."""
+ target_pose = self._get_target_camera_pose(entity_positions)
+ self._camera.set_pose(*target_pose)
+
+ def after_step(self, entity_positions):
+ """Move camera toward its target pose."""
+ target_pose = self._get_target_camera_pose(entity_positions)
+ cur_pose = self._camera.get_pose()
+ smoothing_update_speed = self._smoothing_update_speed
+ filtered_pose = [
+ target_val * smoothing_update_speed +
+ current_val * (1 - smoothing_update_speed)
+ for target_val, current_val in zip(target_pose, cur_pose)
+ ]
+ self._camera.set_pose(*filtered_pose)
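Review note on `MultiplayerTrackingCamera` in camera.py: the target-pose formula and the exponential filter can be exercised without MuJoCo. The sketch below is a standalone illustration, not part of the patch. `Pose` is a hypothetical stand-in for `engine.Pose` that reuses only the field names visible in the diff, and the function names `target_pose` and `smooth` are likewise assumptions.

```python
import collections

import numpy as np

# Hypothetical stand-in for `engine.Pose`; field names follow the diff.
Pose = collections.namedtuple(
    'Pose', ['lookat', 'distance', 'azimuth', 'elevation'])


def target_pose(entity_positions, min_distance, distance_factor,
                azimuth=90, elevation=-45):
  """Looks at the centroid from min_distance + distance_factor * d_max."""
  stacked = np.stack(entity_positions)
  centroid = np.mean(stacked, axis=0)
  d_max = np.max(np.linalg.norm(stacked - centroid, axis=1))
  return Pose(centroid, min_distance + distance_factor * d_max,
              azimuth, elevation)


def smooth(current, target, speed):
  """Exponential filter: speed=1 jumps to target, speed=0 never moves."""
  return Pose(*[t * speed + c * (1 - speed) for t, c in zip(target, current)])


positions = [np.array([-1., 0., 0.]), np.array([1., 0., 0.])]
goal = target_pose(positions, min_distance=2., distance_factor=3.)
start = Pose(np.zeros(3), 1., 90, -45)
halfway = smooth(start, goal, speed=0.5)
```

With two entities one unit either side of the origin, `d_max` is 1, so the target distance is 2 + 3 * 1 = 5, and a filter speed of 0.5 moves the camera distance from 1 to 3 in one step.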
diff --git a/dm_control/locomotion/soccer/explore.py b/dm_control/locomotion/soccer/explore.py
new file mode 100644
index 00000000..20521201
--- /dev/null
+++ b/dm_control/locomotion/soccer/explore.py
@@ -0,0 +1,55 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Interactive viewer for MuJoCo soccer environment."""
+
+import functools
+from absl import app
+from absl import flags
+from dm_control.locomotion import soccer
+from dm_control import viewer
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_enum("walker_type", "BOXHEAD", ["BOXHEAD", "ANT", "HUMANOID"],
+ "The type of walker to explore with.")
+flags.DEFINE_bool(
+ "enable_field_box", True,
+ "If `True`, enable physical bounding box enclosing the ball"
+ " (but not the players).")
+flags.DEFINE_bool("disable_walker_contacts", False,
+ "If `True`, disable walker-walker contacts.")
+flags.DEFINE_bool(
+ "terminate_on_goal", False,
+ "If `True`, the episode terminates upon a goal being scored.")
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ viewer.launch(
+ environment_loader=functools.partial(
+ soccer.load,
+ team_size=2,
+ walker_type=soccer.WalkerType[FLAGS.walker_type],
+ disable_walker_contacts=FLAGS.disable_walker_contacts,
+ enable_field_box=FLAGS.enable_field_box,
+ keep_aspect_ratio=True,
+ terminate_on_goal=FLAGS.terminate_on_goal))
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/dm_control/locomotion/soccer/humanoid.py b/dm_control/locomotion/soccer/humanoid.py
new file mode 100644
index 00000000..e066b5b4
--- /dev/null
+++ b/dm_control/locomotion/soccer/humanoid.py
@@ -0,0 +1,226 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A CMU humanoid walker specialised visually for soccer."""
+
+import enum
+import os
+
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'humanoid')
+_MAX_WALKER_ID = 10
+_INVALID_WALKER_ID = 'walker_id must be in [0-{}], got: {{}}.'.format(
+ _MAX_WALKER_ID)
+
+_INTERIOR_GEOMS = frozenset({
+ 'lhipjoint', 'rhipjoint', 'lfemur', 'lowerback', 'upperback', 'rclavicle',
+ 'lclavicle', 'thorax', 'lhumerus', 'root_geom', 'lowerneck', 'rhumerus',
+ 'rfemur'
+})
+
+
+def _add_visual_only_geoms(mjcf_root):
+ """Introduce visual only geoms to complement the `JERSEY` visual."""
+ lowerneck = mjcf_root.find('body', 'lowerneck')
+ neck_offset = 0.066 - 0.0452401
+ lowerneck.add(
+ 'geom',
+ name='halfneck',
+ # shrink neck radius from 0.06 to 0.05 else it pokes through shirt
+ size=(0.05, 0.02279225 - neck_offset),
+ pos=(-0.00165071, 0.0452401 + neck_offset, 0.00534359),
+ quat=(0.66437, 0.746906, 0.027253, 0),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ rgba=(.7, .5, .3, 1))
+ lhumerus = mjcf_root.find('body', 'lhumerus')
+ humerus_offset = 0.20 - 0.138421
+ lhumerus.add(
+ 'geom',
+ name='lelbow',
+ size=(0.035, 0.1245789 - humerus_offset),
+ pos=(0.0, -0.138421 - humerus_offset, 0.0),
+ quat=(0.612372, -0.612372, 0.353553, 0.353553),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ rgba=(.7, .5, .3, 1))
+ rhumerus = mjcf_root.find('body', 'rhumerus')
+ humerus_offset = 0.20 - 0.138421
+ rhumerus.add(
+ 'geom',
+ name='relbow',
+ size=(0.035, 0.1245789 - humerus_offset),
+ pos=(0.0, -0.138421 - humerus_offset, 0.0),
+ quat=(0.612372, -0.612372, -0.353553, -0.353553),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ rgba=(.7, .5, .3, 1))
+ lfemur = mjcf_root.find('body', 'lfemur')
+ femur_offset = 0.384 - 0.202473
+ lfemur.add(
+ 'geom',
+ name='lknee',
+ # shrink knee radius from 0.06 to 0.055 else it pokes through shorts
+ size=(0.055, 0.1822257 - femur_offset),
+ pos=(-5.0684e-08, -0.202473 - femur_offset, 0),
+ quat=(0.696364, -0.696364, -0.122788, -0.122788),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ rgba=(.7, .5, .3, 1))
+ rfemur = mjcf_root.find('body', 'rfemur')
+ femur_offset = 0.384 - 0.202473
+ rfemur.add(
+ 'geom',
+ name='rknee',
+ # shrink knee radius from 0.06 to 0.055 else it pokes through shorts
+ size=(0.055, 0.1822257 - femur_offset),
+ pos=(-5.0684e-08, -0.202473 - femur_offset, 0),
+ quat=(0.696364, -0.696364, 0.122788, 0.122788),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ rgba=(.7, .5, .3, 1))
+
+
+class Humanoid(cmu_humanoid.CMUHumanoidPositionControlled):
+ """A CMU humanoid walker specialised visually for soccer."""
+
+ class Visual(enum.Enum):
+ GEOM = 1
+ JERSEY = 2
+
+ def _build(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
+ visual,
+ marker_rgba,
+ walker_id=None,
+ initializer=None,
+ name='walker'):
+ """Build a soccer-specific Humanoid walker."""
+ if not isinstance(visual, Humanoid.Visual):
+ raise ValueError('`visual` must be one of `Humanoid.Visual`.')
+
+ if len(marker_rgba) != 4:
+ raise ValueError('`marker_rgba` must be a sequence of length 4.')
+
+ if walker_id is None and visual != Humanoid.Visual.GEOM:
+ raise ValueError(
+ '`walker_id` must be set unless `visual` is set to `Visual.GEOM`.')
+
+ if walker_id is not None and not 0 <= walker_id <= _MAX_WALKER_ID:
+ raise ValueError(_INVALID_WALKER_ID.format(walker_id))
+
+ if visual == Humanoid.Visual.JERSEY:
+ team = 'R' if marker_rgba[0] > marker_rgba[2] else 'B'
+ marker_rgba = None # disable geom coloring for the non-geom visual.
+ else:
+ marker_rgba = list(marker_rgba[:-1]) + [.7] # avoid mutating the input.
+
+ super(Humanoid, self)._build(
+ marker_rgba=marker_rgba,
+ initializer=initializer,
+ include_face=True)
+
+ self._mjcf_root.model = name
+
+ # Changes to humanoid geoms for visual improvements.
+ # Hands: hide hand geoms and add slightly larger visual geoms.
+ for hand_name in ['lhand', 'rhand']:
+ hand = self._mjcf_root.find('body', hand_name)
+ for geom in hand.find_all('geom'):
+ geom.rgba = (0, 0, 0, 0)
+ if geom.name == hand_name:
+ geom_size = geom.size * 1.3 # Palm rescaling.
+ else:
+ geom_size = geom.size * 1.5 # Finger rescaling.
+ geom.parent.add(
+ 'geom',
+ name=geom.name + '_visual',
+ type=geom.type,
+ quat=geom.quat,
+ mass=0,
+ contype=0,
+ conaffinity=0,
+ size=geom_size,
+ pos=geom.pos * 1.5)
+
+ # Lighting: remove tracking light as we have multiple walkers in the scene.
+ tracking_light = self._mjcf_root.find('light', 'tracking_light')
+ tracking_light.remove()
+
+ if visual == Humanoid.Visual.JERSEY:
+ shirt_number = walker_id + 1
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='skin',
+ type='2d',
+ file=os.path.join(_ASSETS_PATH, f'{team}_{shirt_number:02d}.png'))
+ self._mjcf_root.asset.add('material', name='skin', texture='skin')
+ self._mjcf_root.asset.add(
+ 'skin',
+ name='skin',
+ file=os.path.join(_ASSETS_PATH, 'jersey.skn'),
+ material='skin')
+ for geom in self._mjcf_root.find_all('geom'):
+ if geom.name in _INTERIOR_GEOMS:
+ geom.rgba = (0.0, 0.0, 0.0, 0.0)
+ _add_visual_only_geoms(self._mjcf_root)
+
+ # Initialize previous action.
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ @property
+ def marker_geoms(self):
+ """Returns a sequence of marker geoms to be colored visually."""
+ marker_geoms = []
+
+ face = self._mjcf_root.find('geom', 'face')
+ if face is not None:
+ marker_geoms.append(face)
+
+ marker_geoms += self._mjcf_root.find('body', 'rfoot').find_all('geom')
+ marker_geoms += self._mjcf_root.find('body', 'lfoot').find_all('geom')
+ return marker_geoms + [
+ self._mjcf_root.find('geom', 'lowerneck'),
+ self._mjcf_root.find('geom', 'lclavicle'),
+ self._mjcf_root.find('geom', 'rclavicle'),
+ self._mjcf_root.find('geom', 'thorax'),
+ self._mjcf_root.find('geom', 'upperback'),
+ self._mjcf_root.find('geom', 'lowerback'),
+ self._mjcf_root.find('geom', 'rfemur'),
+ self._mjcf_root.find('geom', 'lfemur'),
+ self._mjcf_root.find('geom', 'root_geom'),
+ ]
+
+ def initialize_episode(self, physics, random_state):
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ def apply_action(self, physics, action, random_state):
+ super().apply_action(physics, action, random_state)
+
+ # Updates previous action.
+ self._prev_action[:] = action
+
+ @property
+ def prev_action(self):
+ return self._prev_action
diff --git a/dm_control/locomotion/soccer/humanoid_test.py b/dm_control/locomotion/soccer/humanoid_test.py
new file mode 100644
index 00000000..cf36f858
--- /dev/null
+++ b/dm_control/locomotion/soccer/humanoid_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for dm_control.locomotion.soccer.Humanoid."""
+
+import itertools
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion.soccer import humanoid
+from dm_control.locomotion.soccer import team
+
+
+class HumanoidTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ itertools.product(
+ humanoid.Humanoid.Visual,
+ (team.RGBA_RED, team.RGBA_BLUE),
+ (None, 0, 10),
+ ))
+ def test_instantiation(self, visual, marker_rgba, walker_id):
+ if visual != humanoid.Humanoid.Visual.GEOM and walker_id is None:
+ self.skipTest('Invalid configuration skipped.')
+ humanoid.Humanoid(
+ visual=visual, marker_rgba=marker_rgba, walker_id=walker_id)
+
+ @parameterized.parameters(-1, 11)
+ def test_invalid_walker_id(self, walker_id):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, humanoid._INVALID_WALKER_ID.format(walker_id)):
+ humanoid.Humanoid(
+ visual=humanoid.Humanoid.Visual.JERSEY,
+ walker_id=walker_id,
+ marker_rgba=team.RGBA_BLUE)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/soccer/initializers.py b/dm_control/locomotion/soccer/initializers.py
new file mode 100644
index 00000000..0be4a287
--- /dev/null
+++ b/dm_control/locomotion/soccer/initializers.py
@@ -0,0 +1,126 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Soccer task episode initializers."""
+
+import abc
+import numpy as np
+
+
+_INIT_BALL_Z = 0.5
+_SPAWN_RATIO = 0.6
+
+
+class Initializer(metaclass=abc.ABCMeta):
+
+ @abc.abstractmethod
+ def __call__(self, task, physics, random_state):
+ """Initialize episode for a task."""
+
+
+class UniformInitializer(Initializer):
+ """Uniformly initialize walkers and soccer ball over spawn_range."""
+
+ def __init__(self,
+ spawn_ratio=_SPAWN_RATIO,
+ init_ball_z=_INIT_BALL_Z,
+ max_collision_avoidance_retries=100):
+ self._spawn_ratio = spawn_ratio
+ self._init_ball_z = init_ball_z
+
+ # Lazily initialize geom ids for contact avoidance.
+ self._ball_geom_ids = None
+ self._walker_geom_ids = None
+ self._all_geom_ids = None
+ self._max_retries = max_collision_avoidance_retries
+
+ def _initialize_ball(self, ball, spawn_range, physics, random_state):
+ """Initialize ball in given spawn_range."""
+ if isinstance(spawn_range, np.ndarray):
+ x, y = random_state.uniform(-spawn_range, spawn_range)
+ elif isinstance(spawn_range, (list, tuple)) and len(spawn_range) == 2:
+ x, y = random_state.uniform(spawn_range[0], spawn_range[1])
+ else:
+ raise ValueError(
+ 'Unsupported spawn_range. Must be ndarray or list/tuple of length 2.')
+ ball.set_pose(physics, [x, y, self._init_ball_z])
+ # Note: this method is not always called immediately after `physics.reset()`
+ # so we need to explicitly zero out the velocity.
+ ball.set_velocity(physics, velocity=0., angular_velocity=0.)
+
+ def _initialize_walker(self, walker, spawn_range, physics, random_state):
+ """Uniformly initialize walker in spawn_range."""
+ walker.reinitialize_pose(physics, random_state)
+ x, y = random_state.uniform(-spawn_range, spawn_range)
+ (_, _, z), quat = walker.get_pose(physics)
+ walker.set_pose(physics, [x, y, z], quat)
+ rotation = random_state.uniform(-np.pi, np.pi)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walker.shift_pose(physics, quaternion=quat)
+ # Note: this method is not always called immediately after `physics.reset()`
+ # so we need to explicitly zero out the velocity.
+ walker.set_velocity(physics, velocity=0., angular_velocity=0.)
+
+ def _initialize_entities(self, task, physics, random_state):
+ spawn_range = np.asarray(task.arena.size) * self._spawn_ratio
+ self._initialize_ball(task.ball, spawn_range, physics, random_state)
+ for player in task.players:
+ self._initialize_walker(player.walker, spawn_range, physics, random_state)
+
+ def _initialize_geom_ids(self, task, physics):
+    self._ball_geom_ids = {physics.bind(task.ball.geom).element_id}
+ self._walker_geom_ids = []
+ for player in task.players:
+ walker_geoms = player.walker.mjcf_model.find_all('geom')
+ self._walker_geom_ids.append(set(physics.bind(walker_geoms).element_id))
+
+ self._all_geom_ids = set(self._ball_geom_ids)
+ for walker_geom_ids in self._walker_geom_ids:
+ self._all_geom_ids |= walker_geom_ids
+
+ def _has_relevant_contact(self, contact, geom_ids):
+ other_geom_ids = self._all_geom_ids - geom_ids
+ if ((contact.geom1 in geom_ids and contact.geom2 in other_geom_ids) or
+ (contact.geom2 in geom_ids and contact.geom1 in other_geom_ids)):
+ return True
+ return False
+
+ def __call__(self, task, physics, random_state):
+ # Initialize geom_ids for collision detection.
+ if not self._all_geom_ids:
+ self._initialize_geom_ids(task, physics)
+
+ num_retries = 0
+ while True:
+ self._initialize_entities(task, physics, random_state)
+
+ should_retry = False
+ physics.forward() # forward physics for contact resolution.
+ for contact in physics.data.contact:
+ if self._has_relevant_contact(contact, self._ball_geom_ids):
+ should_retry = True
+ break
+ for walker_geom_ids in self._walker_geom_ids:
+ if self._has_relevant_contact(contact, walker_geom_ids):
+ should_retry = True
+ break
+
+ if not should_retry:
+ break
+
+ num_retries += 1
+ if num_retries > self._max_retries:
+ raise RuntimeError('UniformInitializer: `max_retries` (%d) exceeded.' %
+ self._max_retries)
diff --git a/dm_control/locomotion/soccer/loader_test.py b/dm_control/locomotion/soccer/loader_test.py
new file mode 100644
index 00000000..5a357fdd
--- /dev/null
+++ b/dm_control/locomotion/soccer/loader_test.py
@@ -0,0 +1,103 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.load."""
+
+
+from absl import logging
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion import soccer
+import numpy as np
+
+
+class LoadTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("2vs2_nocontacts", 2, True), ("2vs2_contacts", 2, False),
+ ("1vs1_nocontacts", 1, True), ("1vs1_contacts", 1, False))
+ def test_load_env(self, team_size, disable_walker_contacts):
+ env = soccer.load(team_size=team_size, time_limit=2.,
+ disable_walker_contacts=disable_walker_contacts)
+ action_specs = env.action_spec()
+
+ random_state = np.random.RandomState(0)
+ time_step = env.reset()
+ while not time_step.last():
+ actions = []
+ for action_spec in action_specs:
+ action = random_state.uniform(
+ action_spec.minimum, action_spec.maximum, size=action_spec.shape)
+ actions.append(action)
+ time_step = env.step(actions)
+
+ for i in range(len(action_specs)):
+ logging.info(
+ "Player %d: reward = %s, discount = %s, observations = %s.", i,
+ time_step.reward[i], time_step.discount, time_step.observation[i])
+
+ def assertSameObservation(self, expected_observation, actual_observation):
+ self.assertLen(actual_observation, len(expected_observation))
+ for player_id in range(len(expected_observation)):
+ expected_player_observations = expected_observation[player_id]
+ actual_player_observations = actual_observation[player_id]
+ expected_keys = expected_player_observations.keys()
+ actual_keys = actual_player_observations.keys()
+ msg = ("Observation keys differ for player {}.\nExpected: {}.\nActual: {}"
+ .format(player_id, expected_keys, actual_keys))
+ self.assertEqual(expected_keys, actual_keys, msg)
+ for key in expected_player_observations:
+ expected_array = expected_player_observations[key]
+ actual_array = actual_player_observations[key]
+ msg = ("Observation {!r} differs for player {}.\nExpected:\n{}\n"
+ "Actual:\n{}"
+ .format(key, player_id, expected_array, actual_array))
+ np.testing.assert_array_equal(expected_array, actual_array,
+ err_msg=msg)
+
+ @parameterized.parameters(True, False)
+ def test_same_first_observation_if_same_seed(self, disable_walker_contacts):
+ seed = 42
+ timestep_1 = soccer.load(
+ team_size=2,
+ random_state=seed,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ timestep_2 = soccer.load(
+ team_size=2,
+ random_state=seed,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ self.assertSameObservation(timestep_1.observation, timestep_2.observation)
+
+ @parameterized.parameters(True, False)
+ def test_different_first_observation_if_different_seed(
+ self, disable_walker_contacts):
+ timestep_1 = soccer.load(
+ team_size=2,
+ random_state=1,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ timestep_2 = soccer.load(
+ team_size=2,
+ random_state=2,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ try:
+ self.assertSameObservation(timestep_1.observation, timestep_2.observation)
+ except AssertionError:
+ pass
+ else:
+ self.fail("Observations are unexpectedly identical.")
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/locomotion/soccer/observables.py b/dm_control/locomotion/soccer/observables.py
new file mode 100644
index 00000000..1a926daa
--- /dev/null
+++ b/dm_control/locomotion/soccer/observables.py
@@ -0,0 +1,451 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Soccer observables modules."""
+
+import abc
+
+from dm_control.composer.observation import observable as base_observable
+from dm_control.locomotion.soccer import team as team_lib
+import numpy as np
+
+
+class ObservablesAdder(metaclass=abc.ABCMeta):
+ """A callable that adds a set of per-player observables for a task."""
+
+ @abc.abstractmethod
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+
+class MultiObservablesAdder(ObservablesAdder):
+ """Applies multiple `ObservablesAdder`s to a soccer task and player."""
+
+ def __init__(self, observables):
+ """Initializes a `MultiObservablesAdder` instance.
+
+ Args:
+ observables: A list of `ObservablesAdder` instances.
+ """
+ self._observables = observables
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+ for observable in self._observables:
+ observable(task, player)
+
+
+class CoreObservablesAdder(ObservablesAdder):
+ """Core set of per-player observables."""
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+ # Enable proprioceptive observables.
+ self._add_player_proprio_observables(player)
+
+ # Add egocentric observations of soccer ball.
+ self._add_player_observables_on_ball(player, task.ball)
+
+ # Add egocentric observations of others.
+ teammate_id = 0
+ opponent_id = 0
+ for other in task.players:
+ if other is player:
+ continue
+ # Infer team prefix for `other` conditioned on `player.team`.
+ if player.team != other.team:
+ prefix = 'opponent_{}'.format(opponent_id)
+ opponent_id += 1
+ else:
+ prefix = 'teammate_{}'.format(teammate_id)
+ teammate_id += 1
+
+ self._add_player_observables_on_other(player, other, prefix)
+
+ self._add_player_arena_observables(player, task.arena)
+
+ # Add per player game statistics.
+ self._add_player_stats_observables(task, player)
+
+ def _add_player_observables_on_other(self, player, other, prefix):
+ """Add observables of another player in this player's egocentric frame.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observables to.
+ other: A `Walker` instance corresponding to a different player.
+ prefix: A string specifying a prefix to apply to the names of observables
+ belonging to `player`.
+ """
+ if player is other:
+ raise ValueError('Cannot add egocentric observables of player on itself.')
+
+ sensors = []
+ for effector in other.walker.end_effectors:
+ name = effector.name + '_' + prefix + '_end_effector'
+ sensors.append(player.walker.mjcf_model.sensor.add(
+ 'framepos', name=name,
+ objtype=effector.tag, objname=effector,
+ reftype='body', refname=player.walker.root_body))
+ def _egocentric_end_effectors_xpos(physics):
+ return np.reshape(physics.bind(sensors).sensordata, -1)
+ # Adds end effectors of the other agents in the player's egocentric frame.
+ name = '{}_ego_end_effectors_pos'.format(prefix)
+ player.walker.obs_on_other[name] = sensors
+ player.walker.observables.add_observable(
+ name,
+ base_observable.Generic(_egocentric_end_effectors_xpos))
+
+ ego_linvel_name = '{}_ego_linear_velocity'.format(prefix)
+ ego_linvel_sensor = player.walker.mjcf_model.sensor.add(
+ 'framelinvel', name=ego_linvel_name,
+ objtype='body', objname=other.walker.root_body,
+ reftype='body', refname=player.walker.root_body)
+ player.walker.obs_on_other[ego_linvel_name] = [ego_linvel_sensor]
+ player.walker.observables.add_observable(
+ ego_linvel_name,
+ base_observable.MJCFFeature('sensordata', ego_linvel_sensor))
+
+ ego_pos_name = '{}_ego_position'.format(prefix)
+ ego_pos_sensor = player.walker.mjcf_model.sensor.add(
+ 'framepos', name=ego_pos_name,
+ objtype='body', objname=other.walker.root_body,
+ reftype='body', refname=player.walker.root_body)
+ player.walker.obs_on_other[ego_pos_name] = [ego_pos_sensor]
+ player.walker.observables.add_observable(
+ ego_pos_name,
+ base_observable.MJCFFeature('sensordata', ego_pos_sensor))
+
+ sensors_rot = []
+ obsname = '{}_ego_orientation'.format(prefix)
+ for direction in ['x', 'y', 'z']:
+ sensorname = obsname + '_' + direction
+ sensors_rot.append(player.walker.mjcf_model.sensor.add(
+ 'frame'+direction+'axis', name=sensorname,
+ objtype='body', objname=other.walker.root_body,
+ reftype='body', refname=player.walker.root_body))
+ def _egocentric_orientation(physics):
+ return np.reshape(physics.bind(sensors_rot).sensordata, -1)
+ player.walker.obs_on_other[obsname] = sensors_rot
+ player.walker.observables.add_observable(
+ obsname,
+ base_observable.Generic(_egocentric_orientation))
+
+ # Adds end effectors of the other agents in the other's egocentric frame.
+ # A is seeing B's hand extended to B's right.
+ player.walker.observables.add_observable(
+ '{}_end_effectors_pos'.format(prefix),
+ other.walker.observables.end_effectors_pos)
+
+ def _add_player_observables_on_ball(self, player, ball):
+ """Add observables of the soccer ball in this player's egocentric frame.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observations for.
+ ball: A `SoccerBall` instance.
+ """
+ # Add egocentric ball observations.
+ player.walker.ball_ego_angvel_sensor = player.walker.mjcf_model.sensor.add(
+ 'frameangvel', name='ball_ego_angvel',
+ objtype='body', objname=ball.root_body,
+ reftype='body', refname=player.walker.root_body)
+ player.walker.observables.add_observable(
+ 'ball_ego_angular_velocity',
+ base_observable.MJCFFeature('sensordata',
+ player.walker.ball_ego_angvel_sensor))
+
+ player.walker.ball_ego_pos_sensor = player.walker.mjcf_model.sensor.add(
+ 'framepos', name='ball_ego_pos',
+ objtype='body', objname=ball.root_body,
+ reftype='body', refname=player.walker.root_body)
+ player.walker.observables.add_observable(
+ 'ball_ego_position',
+ base_observable.MJCFFeature('sensordata',
+ player.walker.ball_ego_pos_sensor))
+
+ player.walker.ball_ego_linvel_sensor = player.walker.mjcf_model.sensor.add(
+ 'framelinvel', name='ball_ego_linvel',
+ objtype='body', objname=ball.root_body,
+ reftype='body', refname=player.walker.root_body)
+ player.walker.observables.add_observable(
+ 'ball_ego_linear_velocity',
+ base_observable.MJCFFeature('sensordata',
+ player.walker.ball_ego_linvel_sensor))
+
+ def _add_player_proprio_observables(self, player):
+ """Add proprioceptive observables to the given player.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observations for.
+ """
+ for observable in (player.walker.observables.proprioception +
+ player.walker.observables.kinematic_sensors):
+ observable.enabled = True
+
+ # Also enable previous action observable as part of proprioception.
+ player.walker.observables.prev_action.enabled = True
+
+ def _add_player_arena_observables(self, player, arena):
+ """Add observables of the arena.
+
+ Args:
+ player: A `Walker` instance to which observables will be added.
+ arena: A `Pitch` instance.
+ """
+ # Enable egocentric view of position detectors (goal, field).
+ # Corners named according to walker *facing towards opponent goal*.
+ clockwise_names = [
+ 'team_goal_back_right',
+ 'team_goal_mid',
+ 'team_goal_front_left',
+ 'field_front_left',
+ 'opponent_goal_back_left',
+ 'opponent_goal_mid',
+ 'opponent_goal_front_right',
+ 'field_back_right',
+ ]
+ clockwise_features = [
+ lambda _: arena.home_goal.lower[:2],
+ lambda _: arena.home_goal.mid,
+ lambda _: arena.home_goal.upper[:2],
+ lambda _: arena.field.upper,
+ lambda _: arena.away_goal.upper[:2],
+ lambda _: arena.away_goal.mid,
+ lambda _: arena.away_goal.lower[:2],
+ lambda _: arena.field.lower,
+ ]
+ xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos
+ xpos_xy_callable = lambda p: p.bind(player.walker.root_body).xpos[:2]
+ # A list of egocentric reference origins, one for each of clockwise_features.
+ clockwise_origins = [
+ xpos_xy_callable,
+ xpos_xyz_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ xpos_xyz_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ ]
+ if player.team != team_lib.Team.HOME:
+ half = len(clockwise_features) // 2
+ clockwise_features = clockwise_features[half:] + clockwise_features[:half]
+ clockwise_origins = clockwise_origins[half:] + clockwise_origins[:half]
+
+ for name, feature, origin in zip(clockwise_names, clockwise_features,
+ clockwise_origins):
+ player.walker.observables.add_egocentric_vector(
+ name, base_observable.Generic(feature), origin_callable=origin)
+
+ def _add_player_stats_observables(self, task, player):
+ """Add observables corresponding to game statistics.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+ def _stats_vel_to_ball(physics):
+ dir_ = (
+ physics.bind(task.ball.geom).xpos -
+ physics.bind(player.walker.root_body).xpos)
+ vel_to_ball = np.dot(dir_[:2] / (np.linalg.norm(dir_[:2]) + 1e-7),
+ physics.bind(player.walker.root_body).cvel[3:5])
+ return np.sum(vel_to_ball)
+
+ player.walker.observables.add_observable(
+ 'stats_vel_to_ball', base_observable.Generic(_stats_vel_to_ball))
+
+ def _stats_closest_vel_to_ball(physics):
+ """Velocity to the ball if this walker is the team's closest."""
+ closest = None
+ min_team_dist_to_ball = np.inf
+ for player_ in task.players:
+ if player_.team == player.team:
+ dist_to_ball = np.linalg.norm(
+ physics.bind(task.ball.geom).xpos -
+ physics.bind(player_.walker.root_body).xpos)
+ if dist_to_ball < min_team_dist_to_ball:
+ min_team_dist_to_ball = dist_to_ball
+ closest = player_
+ if closest is player:
+ return _stats_vel_to_ball(physics)
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_closest_vel_to_ball',
+ base_observable.Generic(_stats_closest_vel_to_ball))
+
+ def _stats_veloc_forward(physics):
+ """Player's forward velocity."""
+ return player.walker.observables.veloc_forward(physics)
+
+ player.walker.observables.add_observable(
+ 'stats_veloc_forward', base_observable.Generic(_stats_veloc_forward))
+
+ def _stats_vel_ball_to_goal(physics):
+ """Ball velocity towards opponents' goal."""
+ if player.team == team_lib.Team.HOME:
+ goal = task.arena.away_goal
+ else:
+ goal = task.arena.home_goal
+
+ goal_center = (goal.upper + goal.lower) / 2.
+ direction = goal_center - physics.bind(task.ball.geom).xpos
+ ball_vel_observable = task.ball.observables.linear_velocity
+ ball_vel = ball_vel_observable.observation_callable(physics)()
+
+ norm_dir = np.linalg.norm(direction)
+ normalized_dir = direction / norm_dir if norm_dir else direction
+ return np.sum(np.dot(normalized_dir, ball_vel))
+
+ player.walker.observables.add_observable(
+ 'stats_vel_ball_to_goal',
+ base_observable.Generic(_stats_vel_ball_to_goal))
+
+ def _stats_avg_teammate_dist(physics):
+ """Compute average distance from `walker` to its teammates."""
+ teammate_dists = []
+ for other in task.players:
+ if player is other:
+ continue
+ if other.team != player.team:
+ continue
+ dist = np.linalg.norm(
+ physics.bind(player.walker.root_body).xpos -
+ physics.bind(other.walker.root_body).xpos)
+ teammate_dists.append(dist)
+ return np.mean(teammate_dists) if teammate_dists else 0.
+
+ player.walker.observables.add_observable(
+ 'stats_home_avg_teammate_dist',
+ base_observable.Generic(_stats_avg_teammate_dist))
+
+ def _stats_teammate_spread_out(physics):
+ """Whether the average distance from `walker` to its teammates exceeds 5."""
+ return _stats_avg_teammate_dist(physics) > 5.
+
+ player.walker.observables.add_observable(
+ 'stats_teammate_spread_out',
+ base_observable.Generic(_stats_teammate_spread_out))
+
+ def _stats_home_score(unused_physics):
+ if (task.arena.detected_goal() and
+ task.arena.detected_goal() == player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_home_score', base_observable.Generic(_stats_home_score))
+
+ has_opponent = any([p.team != player.team for p in task.players])
+
+ def _stats_away_score(unused_physics):
+ if (has_opponent and task.arena.detected_goal() and
+ task.arena.detected_goal() != player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_away_score', base_observable.Generic(_stats_away_score))
+
+
+# TODO(b/124848293): add unit-test interception observables.
+class InterceptionObservablesAdder(ObservablesAdder):
+ """Adds observables representing interception events.
+
+ These observables represent events where this player received the ball from
+ another player, or when an opponent intercepted the ball from this player's
+ team. For each type of event there are three different thresholds applied to
+ the distance travelled by the ball since it last made contact with a player
+ (5, 10, or 15 meters).
+
+ For example, on a given timestep `stats_i_received_ball_10m` will be 1 if:
+ * This player just made contact with the ball
+ * The last player to have made contact with the ball was a different player
+ * The ball travelled for at least 10 m since it last hit a player
+ and 0 otherwise.
+
+ Conversely, `stats_opponent_intercepted_ball_10m` will be 1 if:
+ * An opponent just made contact with the ball
+ * The last player to have made contact with the ball was on this player's team
+ * The ball travelled for at least 10 m since it last hit a player
+ """
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+ def _stats_i_received_ball(unused_physics):
+ if (task.ball.hit and task.ball.repossessed and
+ task.ball.last_hit is player):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_i_received_ball',
+ base_observable.Generic(_stats_i_received_ball))
+
+ def _stats_opponent_intercepted_ball(unused_physics):
+ """Indicator of whether an opponent intercepted the ball."""
+ if (task.ball.hit and task.ball.intercepted and
+ task.ball.last_hit.team != player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_opponent_intercepted_ball',
+ base_observable.Generic(_stats_opponent_intercepted_ball))
+
+ for dist in [5, 10, 15]:
+
+ def _stats_i_received_ball_dist(physics, dist=dist):
+ if (_stats_i_received_ball(physics) and
+ task.ball.dist_between_last_hits is not None and
+ task.ball.dist_between_last_hits > dist):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_i_received_ball_%dm' % dist,
+ base_observable.Generic(_stats_i_received_ball_dist))
+
+ def _stats_opponent_intercepted_ball_dist(physics, dist=dist):
+ if (_stats_opponent_intercepted_ball(physics) and
+ task.ball.dist_between_last_hits is not None and
+ task.ball.dist_between_last_hits > dist):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_opponent_intercepted_ball_%dm' % dist,
+ base_observable.Generic(_stats_opponent_intercepted_ball_dist))
diff --git a/dm_control/locomotion/soccer/pitch.py b/dm_control/locomotion/soccer/pitch.py
new file mode 100644
index 00000000..66b20f3e
--- /dev/null
+++ b/dm_control/locomotion/soccer/pitch.py
@@ -0,0 +1,724 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A soccer pitch with home/away goals and one field with position detection."""
+
+import colorsys
+import os
+
+from absl import logging
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.locomotion.soccer import team
+import numpy as np
+
+from dm_control.utils import io as resources
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'pitch')
+
+
+def _get_texture(name):
+ contents = resources.GetResource(
+ os.path.join(_ASSETS_PATH, '{}.png'.format(name)))
+ return mjcf.Asset(contents, '.png')
+
+
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+_TOP_CAMERA_DISTANCE = 95.
+_WALL_HEIGHT = 10.
+_WALL_THICKNESS = .5
+_SIDE_WIDTH = 32. / 6.
+_GROUND_GEOM_GRID_RATIO = 1. / 100 # Grid size for lighting.
+_FIELD_BOX_CONTACT_BIT = 1 << 7 # Use a higher bit to prevent potential clash.
+
+_DEFAULT_PITCH_SIZE = (12, 9)
+_DEFAULT_GOAL_LENGTH_RATIO = 0.33 # Goal length / pitch width.
+
+_GOALPOST_RELATIVE_SIZE = 0.07 # Ratio of the goalpost radius to goal size.
+_NET_RELATIVE_SIZE = 0.01 # Ratio of the net thickness to goal size.
+_SUPPORT_POST_RATIO = 0.75 # Ratio of support post to goalpost radius.
+# Goalposts defined in the unit box [-1, 1]**3 facing to the positive X.
+_GOALPOSTS = {'right_post': (1, -1, -1, 1, -1, 1),
+ 'left_post': (1, 1, -1, 1, 1, 1),
+ 'top_post': (1, -1, 1, 1, 1, 1),
+ 'right_base': (1, -1, -1, -1, -1, -1),
+ 'left_base': (1, 1, -1, -1, 1, -1),
+ 'back_base': (-1, -1, -1, -1, 1, -1),
+ 'right_support': (-1, -1, -1, .2, -1, 1),
+ 'right_top_support': (.2, -1, 1, 1, -1, 1),
+ 'left_support': (-1, 1, -1, .2, 1, 1),
+ 'left_top_support': (.2, 1, 1, 1, 1, 1)}
+# Vertices of net polygons, reshaped to 4x3 arrays.
+_NET = {'top': _GOALPOSTS['right_top_support'] + _GOALPOSTS['left_top_support'],
+ 'back': _GOALPOSTS['right_support'] + _GOALPOSTS['left_support'],
+ 'left': _GOALPOSTS['left_base'] + _GOALPOSTS['left_top_support'],
+ 'right': _GOALPOSTS['right_base'] + _GOALPOSTS['right_top_support']}
+_NET = {key: np.array(value).reshape(4, 3) for key, value in _NET.items()}
+
+# Number of visual hoarding boxes per side of the pitch.
+_NUM_HOARDING = 30
+
+
+def _top_down_cam_fovy(size, top_camera_distance):
+ return (360 / np.pi) * np.arctan2(_TOP_CAMERA_Y_PADDING_FACTOR * max(size),
+ top_camera_distance)
+
+
+def _wall_pos_xyaxes(size):
+ """Infers position and size of bounding walls given pitch size.
+
+ Walls are placed around `ground_geom` that represents the pitch. Note that
+ the ball cannot travel beyond `field`, whereas walkers can walk outside of
+ the `field` but not beyond the surrounding walls.
+
+ Args:
+ size: a tuple of (length, width) of the pitch.
+
+ Returns:
+ a list of 4 tuples, each representing the position and xyaxes of a wall
+ plane. In order, walls are placed along x-negative, x-positive, y-negative,
+ y-positive relative to the center of the pitch.
+ """
+ return [
+ ((0., -size[1], 0.), (-1, 0, 0, 0, 0, 1)),
+ ((0., size[1], 0.), (1, 0, 0, 0, 0, 1)),
+ ((-size[0], 0., 0.), (0, 1, 0, 0, 0, 1)),
+ ((size[0], 0., 0.), (0, -1, 0, 0, 0, 1)),
+ ]
+
+
+def _fieldbox_pos_size(field_size, goal_size):
+ """Infers position and size of fieldbox given pitch size.
+
+ Walls are placed around the field so that the ball cannot travel beyond
+ `field`, whereas walkers can walk outside of the `field` but not beyond the
+ surrounding pitch. Holes are left in the fieldbox at the goal positions to
+ enable scoring.
+
+ Args:
+ field_size: a tuple of (length, width) of the field.
+ goal_size: a tuple of (unused_depth, width, height) of the goal.
+
+ Returns:
+ a list of 8 tuples, each representing the position and size of a wall box.
+ """
+
+ box_half_height = 20.
+ corner_pos_y = 0.5 * (field_size[1] + goal_size[1])
+ corner_size_y = 0.5 * (field_size[1] - goal_size[1])
+ thickness = 1.0
+ top_pos_z = box_half_height + goal_size[2]
+ top_size_z = box_half_height - goal_size[2]
+ wall_offset_x = field_size[0] + thickness
+ wall_offset_y = field_size[1] + thickness
+ return [
+ ((0., -wall_offset_y, box_half_height),
+ (field_size[0], thickness, box_half_height)), # near side
+ ((0., wall_offset_y, box_half_height),
+ (field_size[0], thickness, box_half_height)), # far side
+ ((-wall_offset_x, -corner_pos_y, box_half_height),
+ (thickness, corner_size_y, box_half_height)), # left near corner
+ ((-wall_offset_x, 0., top_pos_z),
+ (thickness, goal_size[1], top_size_z)), # left top corner
+ ((-wall_offset_x, corner_pos_y, box_half_height),
+ (thickness, corner_size_y, box_half_height)), # left far corner
+ ((wall_offset_x, -corner_pos_y, box_half_height),
+ (thickness, corner_size_y, box_half_height)), # right near corner
+ ((wall_offset_x, 0., top_pos_z),
+ (thickness, goal_size[1], top_size_z)), # right top corner
+ ((wall_offset_x, corner_pos_y, box_half_height),
+ (thickness, corner_size_y, box_half_height)), # right far corner
+ ]
+
+
+def _roof_size(size):
+ return (size[0], size[1], _WALL_THICKNESS)
+
+
+def _reposition_corner_lights(lights, size):
+ """Place four lights at the corners of the pitch."""
+ mean_size = 0.5 * sum(size)
+ height = mean_size * 2/3
+ counter = 0
+ for x in [-size[0], size[0]]:
+ for y in [-size[1], size[1]]:
+ position = np.array((x, y, height))
+ direction = -np.array((x, y, height*2))
+ lights[counter].pos = position
+ lights[counter].dir = direction
+ counter += 1
+
+
+def _goalpost_radius(size):
+ """Compute goal post radius as scaled average goal size."""
+ return _GOALPOST_RELATIVE_SIZE * sum(size) / 3.
+
+
+def _post_radius(goalpost_name, goalpost_radius):
+ """Compute the radius of a specific goalpost."""
+ radius = goalpost_radius
+ if 'top' in goalpost_name:
+ radius *= 1.01 # Prevent z-fighting at the corners.
+ if 'support' in goalpost_name:
+    radius *= _SUPPORT_POST_RATIO  # Support posts are a bit narrower.
+ return radius
+
+
+def _goalpost_fromto(unit_fromto, size, pos, direction):
+ """Rotate, scale and translate the `fromto` attribute of a goalpost.
+
+ The goalposts are defined in the unit cube [-1, 1]**3 using the MuJoCo
+ `fromto` specifier for capsules. They are then flipped according to whether
+ they face in the +x or -x direction, scaled and moved.
+
+ Args:
+ unit_fromto: two concatenated 3-vectors in the unit cube in xyzxyz order.
+ size: a 3-vector, scaling of the goal.
+ pos: a 3-vector, goal position.
+ direction: a 3-vector, either (1,1,1) or (-1,-1,1), direction of the goal
+ along the x-axis.
+
+ Returns:
+ two concatenated 3-vectors, the `fromto` of a goal geom.
+ """
+ fromto = np.array(unit_fromto) * np.hstack((direction, direction))
+ return fromto*np.array(size+size) + np.array(pos+pos)
+
+
+class Goal(props.PositionDetector):
+ """Goal for soccer-like games: A PositionDetector with goalposts."""
+
+ def _make_net_vertices(self, size=(1, 1, 1)):
+ """Make vertices for the four net meshes by offsetting net polygons."""
+ thickness = _NET_RELATIVE_SIZE * sum(size) / 3
+ # Get mesh offsets, compensate for mesh.scale deformation.
+ dx = np.array((thickness / size[0], 0, 0))
+ dy = np.array((0, thickness / size[1], 0))
+ dz = np.array((0, 0, thickness / size[2]))
+ # Make mesh vertices with specified thickness.
+ top = [v+dz for v in _NET['top']] + [v-dz for v in _NET['top']]
+ right = [v+dy for v in _NET['right']] + [v-dy for v in _NET['right']]
+ left = [v+dy for v in _NET['left']] + [v-dy for v in _NET['left']]
+ back = ([v+dz for v in _NET['back'] if v[2] == 1] +
+ [v-dz for v in _NET['back'] if v[2] == 1] +
+ [v+dx for v in _NET['back'] if v[2] == -1] +
+ [v-dx for v in _NET['back'] if v[2] == -1])
+ vertices = {'top': top, 'back': back, 'left': left, 'right': right}
+ return {key: (val*self._direction).flatten()
+ for key, val in vertices.items()}
+
+ def _move_goal(self, pos, size):
+ """Translate and scale the goal."""
+ for geom in self._goal_geoms:
+ unit_fromto = _GOALPOSTS[geom.name]
+ geom.fromto = _goalpost_fromto(unit_fromto, size, pos, self._direction)
+ geom.size = (_post_radius(geom.name, self._goalpost_radius),)
+ if self._make_net:
+ net_vertices = self._make_net_vertices(size)
+ for geom in self._net_geoms:
+ geom.pos = pos
+ geom.mesh.vertex = net_vertices[geom.mesh.name]
+ geom.mesh.scale = size
+
+ def _build(self, direction, net_rgba=(1, 1, 1, .15), make_net=True, **kwargs):
+ """Builds the goalposts and net.
+
+ Args:
+ direction: 1 or -1; whether the goal faces the +x or -x direction.
+ net_rgba: rgba value of the net geoms.
+ make_net: Whether to add net geoms.
+ **kwargs: arguments of PositionDetector superclass, see therein.
+
+ Raises:
+ ValueError: If either `pos` or `size` arrays are not of length 3.
+ ValueError: If direction is not 1 or -1.
+ """
+ if len(kwargs['size']) != 3 or len(kwargs['pos']) != 3:
+ raise ValueError('Only 3D Goals are supported.')
+ if direction not in [1, -1]:
+ raise ValueError('direction must be either 1 or -1.')
+ # Flip both x and y, to maintain left / right name correctness.
+ self._direction = np.array((direction, direction, 1))
+ self._make_net = make_net
+
+ # Force the underlying PositionDetector to a non-visible site group.
+ kwargs['visible'] = False
+ # Make a PositionDetector.
+ super()._build(retain_substep_detections=True, **kwargs)
+
+ # Add goalpost geoms.
+ size = kwargs['size']
+ pos = kwargs['pos']
+ self._goalpost_radius = _goalpost_radius(size)
+ self._goal_geoms = []
+ for geom_name, unit_fromto in _GOALPOSTS.items():
+ geom_fromto = _goalpost_fromto(unit_fromto, size, pos, self._direction)
+ geom_size = (_post_radius(geom_name, self._goalpost_radius),)
+ self._goal_geoms.append(
+ self._mjcf_root.worldbody.add(
+ 'geom',
+ type='capsule',
+ name=geom_name,
+ size=geom_size,
+ fromto=geom_fromto,
+ rgba=self.goalpost_rgba))
+
+ # Add net meshes and geoms.
+ if self._make_net:
+ net_vertices = self._make_net_vertices()
+ self._net_geoms = []
+ for name, vertex in net_vertices.items():
+ mesh = self._mjcf_root.asset.add('mesh', name=name, vertex=vertex)
+ geom = self._mjcf_root.worldbody.add('geom', type='mesh', mesh=mesh,
+ name=name, rgba=net_rgba,
+ contype=0, conaffinity=0)
+ self._net_geoms.append(geom)
+
+ def resize(self, pos, size):
+ """Call PositionDetector.resize(), move the goal."""
+ super().resize(pos, size)
+ self._goalpost_radius = _goalpost_radius(size)
+ self._move_goal(pos, size)
+
+ def set_position(self, physics, pos):
+ """Call PositionDetector.set_position(), move the goal."""
+ super().set_position(pos)
+ size = 0.5*(self.upper - self.lower)
+ self._move_goal(pos, size)
+
+ def _update_detection(self, physics):
+ """Call PositionDetector._update_detection(), then recolor the goalposts."""
+ super()._update_detection(physics)
+ if self._detected and not self._previously_detected:
+ physics.bind(self._goal_geoms).rgba = self.goalpost_detected_rgba
+ elif self._previously_detected and not self._detected:
+ physics.bind(self._goal_geoms).rgba = self.goalpost_rgba
+
+ @property
+ def goalpost_rgba(self):
+ """Goalposts are always opaque."""
+ rgba = self._rgba.copy()
+ rgba[3] = 1
+ return rgba
+
+ @property
+ def goalpost_detected_rgba(self):
+ """Goalposts are always opaque."""
+ detected_rgba = self._detected_rgba.copy()
+ detected_rgba[3] = 1
+ return detected_rgba
+
+
+class Pitch(composer.Arena):
+ """A pitch with a plane, two goals and a field with position detection."""
+
+ def _build(self,
+ size=_DEFAULT_PITCH_SIZE,
+ goal_size=None,
+ top_camera_distance=_TOP_CAMERA_DISTANCE,
+ field_box=False,
+ field_box_offset=0.0,
+ hoarding_color_scheme_id=0,
+ name='pitch'):
+ """Construct a pitch with walls and position detectors.
+
+ Args:
+ size: a tuple of (length, width) of the pitch.
+ goal_size: optional (depth, width, height) indicating the goal size.
+ If not specified, the goal size is inferred from pitch size with a fixed
+ default ratio.
+ top_camera_distance: the distance of the top-down camera to the pitch.
+ field_box: adds a "field box" that collides with the ball but not the
+ walkers.
+ field_box_offset: offset for the fieldbox if used.
+ hoarding_color_scheme_id: An integer with value 0, 1, 2, or 3, specifying
+ a preset scheme for the hoarding colors.
+ name: the name of this arena.
+ """
+ super()._build(name=name)
+ self._size = size
+ self._goal_size = goal_size
+ self._top_camera_distance = top_camera_distance
+ self._hoarding_color_scheme_id = hoarding_color_scheme_id
+
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera',
+ name='top_down',
+ pos=[0, 0, top_camera_distance],
+ zaxis=[0, 0, 1],
+ fovy=_top_down_cam_fovy(self._size, top_camera_distance))
+
+ # Set the `extent`, an "average distance" to 0.1 * pitch length.
+ extent = 0.1 * max(self._size)
+ self._mjcf_root.statistic.extent = extent
+ self._mjcf_root.statistic.center = (0, 0, extent)
+ # The near and far clipping planes are scaled by `extent`.
+ self._mjcf_root.visual.map.zfar = 50 # 5 pitch lengths
+ self._mjcf_root.visual.map.znear = 0.1 / extent # 10 centimeters
+
+ # Add skybox.
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='skybox',
+ type='skybox',
+ builtin='gradient',
+ rgb1=(.7, .9, .9),
+ rgb2=(.03, .09, .27),
+ width=400,
+ height=400)
+
+ # Add and position corner lights.
+ self._corner_lights = [self._mjcf_root.worldbody.add('light', cutoff=60)
+ for _ in range(4)]
+ _reposition_corner_lights(self._corner_lights, size)
+
+ # Increase shadow resolution (default is 1024).
+ self._mjcf_root.visual.quality.shadowsize = 8192
+
+ # Build groundplane.
+ if len(self._size) != 2:
+ raise ValueError('`size` should be a sequence of length 2: got {!r}'
+ .format(self._size))
+ self._field_texture = self._mjcf_root.asset.add(
+ 'texture',
+ type='2d',
+ file=_get_texture('pitch_nologo_l'),
+ name='fieldplane')
+ self._field_material = self._mjcf_root.asset.add(
+ 'material', name='fieldplane', texture=self._field_texture)
+
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ name='ground',
+ type='plane',
+ material=self._field_material,
+ size=list(self._size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO])
+
+ # Build walls.
+ self._walls = []
+ for wall_pos, wall_xyaxes in _wall_pos_xyaxes(self._size):
+ self._walls.append(
+ self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ rgba=[.1, .1, .1, .8],
+ pos=wall_pos,
+ size=[1e-7, 1e-7, 1e-7],
+ xyaxes=wall_xyaxes))
+
+ # Build goal position detectors.
+ # If field_box is enabled, offset goals by `field_box_offset` so that the
+ # ball reaches the goal position detector before bouncing off the field_box.
+ self._fb_offset = field_box_offset if field_box else 0.0
+ goal_size = self._get_goal_size()
+ self._home_goal = Goal(
+ direction=1,
+ make_net=False,
+ pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0,
+ goal_size[2]),
+ size=goal_size,
+ rgba=(.2, .2, 1, 0.5),
+ visible=True,
+ name='home_goal')
+ self.attach(self._home_goal)
+
+ self._away_goal = Goal(
+ direction=-1,
+ make_net=False,
+ pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]),
+ size=goal_size,
+ rgba=(1, .2, .2, 0.5),
+ visible=True,
+ name='away_goal')
+ self.attach(self._away_goal)
+
+ # Build inverted field position detectors.
+ self._field = props.PositionDetector(
+ pos=(0, 0),
+ size=(self._size[0] - 2 * goal_size[0],
+ self._size[1] - 2 * goal_size[0]),
+ inverted=True,
+ visible=False,
+ name='field')
+ self.attach(self._field)
+
+ # Build field perimeter.
+ def _visual_plane():
+ return self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ size=(1, 1, 1),
+ rgba=(0.306, 0.682, 0.223, 1),
+ contype=0,
+ conaffinity=0)
+
+ self._perimeter = [_visual_plane() for _ in range(8)]
+ self._update_perimeter()
+
+ # Build field box.
+ self._field_box = []
+ if field_box:
+ for box_pos, box_size in _fieldbox_pos_size(
+ (self._field.upper - self._field.lower) / 2.0, goal_size):
+ self._field_box.append(
+ self._mjcf_root.worldbody.add(
+ 'geom',
+ type='box',
+ rgba=[.3, .3, .3, .0],
+ pos=box_pos,
+ size=box_size))
+
+ # Build hoarding sites.
+ def _box_site():
+ return self._mjcf_root.worldbody.add('site', type='box', size=(1, 1, 1))
+ self._hoarding = [_box_site() for _ in range(4 * _NUM_HOARDING)]
+ self._update_hoarding()
+
+ def _update_hoarding(self):
+ # Resize, reposition and re-color hoarding box sites.
+ num_boxes = _NUM_HOARDING
+ counter = 0
+ for dim in [0, 1]: # Semantics are [x, y]
+ width = self._get_goal_size()[2] / 8 # Eighth of the goal height.
+ height = self._get_goal_size()[2] / 2 # Half of the goal height.
+ length = self._size[dim]
+ if dim == 1: # Stretch the y-dim length in order to cover the corners.
+ length += 2 * width
+ box_size = height * np.ones(3)
+ box_size[dim] = length / num_boxes
+ box_size[1-dim] = width
+ dim_pos = np.linspace(-length, length, num_boxes, endpoint=False)
+ dim_pos += length / num_boxes # Offset to center.
+ for sign in [-1, 1]:
+ alt_pos = sign * (self._size[1-dim] * np.ones(num_boxes) + width)
+ dim_alt = (dim_pos, alt_pos)
+ for box in range(num_boxes):
+ box_pos = np.array((dim_alt[dim][box], dim_alt[1-dim][box], width))
+ if self._hoarding_color_scheme_id == 0:
+ # Red to blue through green + blue hoarding behind blue goal
+ angle = np.pi + np.arctan2(box_pos[0], -np.abs(box_pos[1]))
+ elif self._hoarding_color_scheme_id == 1:
+ # Red to blue through green + blue hoarding behind red goal
+ angle = np.arctan2(box_pos[0], np.abs(box_pos[1]))
+ elif self._hoarding_color_scheme_id == 2:
+ # Red to blue through purple + blue hoarding behind red goal
+ angle = np.arctan2(box_pos[0], -np.abs(box_pos[1]))
+ elif self._hoarding_color_scheme_id == 3:
+ # Red to blue through purple + blue hoarding behind blue goal
+ angle = np.pi + np.arctan2(box_pos[0], np.abs(box_pos[1]))
+ hue = 0.5 + angle / (2*np.pi) # In [0, 1]
+ hue_offset = .25
+ hue = (hue - hue_offset) % 1.0 # Apply offset and wrap back to [0, 1]
+ saturation = .7
+ value = 1.0
+ col_r, col_g, col_b = colorsys.hsv_to_rgb(hue, saturation, value)
+ self._hoarding[counter].pos = box_pos
+ self._hoarding[counter].size = box_size
+ self._hoarding[counter].rgba = (col_r, col_g, col_b, 1.)
+ counter += 1
+
+ def _update_perimeter(self):
+ # Resize and reposition visual perimeter plane geoms.
+ width = self._get_goal_size()[0]
+ counter = 0
+ for x in [-1, 0, 1]:
+ for y in [-1, 0, 1]:
+ if x == 0 and y == 0:
+ continue
+ size_0 = self._size[0]-2*width if x == 0 else width
+ size_1 = self._size[1]-2*width if y == 0 else width
+ size = [size_0, size_1, max(self._size) * _GROUND_GEOM_GRID_RATIO]
+ pos = (x*(self._size[0]-width), y*(self._size[1]-width), 0)
+ self._perimeter[counter].size = size
+ self._perimeter[counter].pos = pos
+ counter += 1
+
+ def _get_goal_size(self):
+ goal_size = self._goal_size
+ if goal_size is None:
+ goal_size = (
+ _SIDE_WIDTH / 2,
+ self._size[1] * _DEFAULT_GOAL_LENGTH_RATIO,
+ _SIDE_WIDTH / 2,
+ )
+ return goal_size
+
+ def register_ball(self, ball):
+ self._home_goal.register_entities(ball)
+ self._away_goal.register_entities(ball)
+
+ if self._field_box:
+ # Geoms a and b collide if:
+ # (a.contype & b.conaffinity) || (b.contype & a.conaffinity) != 0.
+ # See: http://www.mujoco.org/book/computation.html#Collision
+ ball.geom.contype = (ball.geom.contype or 1) | _FIELD_BOX_CONTACT_BIT
+ for wall in self._field_box:
+ wall.conaffinity = _FIELD_BOX_CONTACT_BIT
+ wall.contype = _FIELD_BOX_CONTACT_BIT
+ else:
+ self._field.register_entities(ball)
+
+ def detected_goal(self):
+ """Returns the team that scored a goal, or `None` if no goal is detected."""
+ if self._home_goal.detected_entities:
+ return team.Team.AWAY
+ if self._away_goal.detected_entities:
+ return team.Team.HOME
+ return None
+
+ def detected_off_court(self):
+ return self._field.detected_entities
+
+ @property
+ def size(self):
+ return self._size
+
+ @property
+ def home_goal(self):
+ return self._home_goal
+
+ @property
+ def away_goal(self):
+ return self._away_goal
+
+ @property
+ def field(self):
+ return self._field
+
+ @property
+ def ground_geom(self):
+ return self._ground_geom
+
+
+class RandomizedPitch(Pitch):
+ """RandomizedPitch that randomizes its size between (min_size, max_size)."""
+
+ def __init__(self,
+ min_size,
+ max_size,
+ randomizer=None,
+ keep_aspect_ratio=False,
+ goal_size=None,
+ field_box=False,
+ field_box_offset=0.0,
+ top_camera_distance=_TOP_CAMERA_DISTANCE,
+ name='randomized_pitch'):
+ """Construct a randomized pitch.
+
+ Args:
+ min_size: a tuple of minimum (length, width) of the pitch.
+ max_size: a tuple of maximum (length, width) of the pitch.
+ randomizer: a callable returning a ratio in [0., 1.] that interpolates
+ between min_size and max_size.
+ keep_aspect_ratio: if `True`, keep the aspect ratio constant during
+ randomization.
+ goal_size: optional (depth, width, height) indicating the goal size.
+ If not specified, the goal size is inferred from pitch size with a fixed
+ default ratio.
+ field_box: optional boolean indicating whether to construct a field box
+ that contains the ball (but not the walkers).
+ field_box_offset: offset for the fieldbox if used.
+ top_camera_distance: the distance of the top-down camera to the pitch.
+ name: the name of this arena.
+ """
+ super().__init__(
+ size=max_size,
+ goal_size=goal_size,
+ top_camera_distance=top_camera_distance,
+ field_box=field_box,
+ field_box_offset=field_box_offset,
+ name=name)
+
+ self._min_size = min_size
+ self._max_size = max_size
+
+ self._randomizer = randomizer or distributions.Uniform()
+ self._keep_aspect_ratio = keep_aspect_ratio
+
+ # Sample a new size and regenerate the soccer pitch.
+ logging.info('%s between (%s, %s) with %s', self.__class__.__name__,
+ min_size, max_size, self._randomizer)
+
+ def _resize_goals(self, goal_size):
+ self._home_goal.resize(
+ pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0, goal_size[2]),
+ size=goal_size)
+ self._away_goal.resize(
+ pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]),
+ size=goal_size)
+
+ def initialize_episode_mjcf(self, random_state):
+ super().initialize_episode_mjcf(random_state)
+ min_len, min_wid = self._min_size
+ max_len, max_wid = self._max_size
+
+ if self._keep_aspect_ratio:
+ len_ratio = self._randomizer(random_state=random_state)
+ wid_ratio = len_ratio
+ else:
+ len_ratio = self._randomizer(random_state=random_state)
+ wid_ratio = self._randomizer(random_state=random_state)
+
+ self._size = (min_len + len_ratio * (max_len - min_len),
+ min_wid + wid_ratio * (max_wid - min_wid))
+
+ # Reset top_down camera field of view.
+ self._top_camera.fovy = _top_down_cam_fovy(self._size,
+ self._top_camera_distance)
+
+ # Resize ground perimeter.
+ self._update_perimeter()
+
+ # Resize and reposition walls and roof geoms.
+ for i, (wall_pos, _) in enumerate(_wall_pos_xyaxes(self._size)):
+ self._walls[i].pos = wall_pos
+
+ goal_size = self._get_goal_size()
+ self._resize_goals(goal_size)
+
+ # Resize inverted field position detectors.
+ field_size = (self._size[0] -2*goal_size[0], self._size[1] -2*goal_size[0])
+ self._field.resize(pos=(0, 0), size=field_size)
+
+ # Resize ground geom size.
+ self._ground_geom.size = list(
+ field_size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO]
+
+ # Resize and reposition field box geoms.
+ if self._field_box:
+ for i, (pos, size) in enumerate(
+ _fieldbox_pos_size((self._field.upper - self._field.lower) / 2.0,
+ goal_size)):
+ self._field_box[i].pos = pos
+ self._field_box[i].size = size
+
+ # Reposition corner lights.
+ _reposition_corner_lights(
+ self._corner_lights,
+ size=(self._size[0] - 2 * goal_size[0],
+ self._size[1] - 2 * goal_size[0]))
+
+ # Resize, reposition and recolor hoarding geoms.
+ self._update_hoarding()
+
+
+# Mini-football (5v5) dimensions.
+_GOAL_LENGTH = 3.66
+_GOAL_SIDE = 1.22
+
+MINI_FOOTBALL_MIN_AREA_PER_HUMANOID = 100.0
+MINI_FOOTBALL_MAX_AREA_PER_HUMANOID = 350.0
+MINI_FOOTBALL_GOAL_SIZE = (_GOAL_SIDE / 2, _GOAL_LENGTH / 2, _GOAL_SIDE / 2)
diff --git a/dm_control/locomotion/soccer/pitch_test.py b/dm_control/locomotion/soccer/pitch_test.py
new file mode 100644
index 00000000..e845dedd
--- /dev/null
+++ b/dm_control/locomotion/soccer/pitch_test.py
@@ -0,0 +1,83 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.pitch."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.locomotion.soccer import pitch as pitch_lib
+from dm_control.locomotion.soccer import team as team_lib
+import numpy as np
+
+
+class PitchTest(parameterized.TestCase):
+
+ def _pitch_with_ball(self, pitch_size, ball_pos):
+ pitch = pitch_lib.Pitch(size=pitch_size)
+ self.assertEqual(pitch.size, pitch_size)
+
+ sphere = props.Primitive(geom_type='sphere', size=(0.1,), pos=ball_pos)
+ pitch.register_ball(sphere)
+ pitch.attach(sphere)
+
+ env = composer.Environment(
+ composer.NullTask(pitch), random_state=np.random.RandomState(42))
+ env.reset()
+ return pitch
+
+ def test_pitch_none_detected(self):
+ pitch = self._pitch_with_ball((12, 9), (0, 0, 0))
+ self.assertEmpty(pitch.detected_off_court())
+ self.assertIsNone(pitch.detected_goal())
+
+ def test_pitch_detected_off_court(self):
+ pitch = self._pitch_with_ball((12, 9), (20, 0, 0))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertIsNone(pitch.detected_goal())
+
+ def test_pitch_detected_away_goal(self):
+ pitch = self._pitch_with_ball((12, 9), (-9.5, 0, 1))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertEqual(team_lib.Team.AWAY, pitch.detected_goal())
+
+ def test_pitch_detected_home_goal(self):
+ pitch = self._pitch_with_ball((12, 9), (9.5, 0, 1))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertEqual(team_lib.Team.HOME, pitch.detected_goal())
+
+ @parameterized.parameters((True, distributions.Uniform()),
+ (False, distributions.Uniform()))
+ def test_randomize_pitch(self, keep_aspect_ratio, randomizer):
+ pitch = pitch_lib.RandomizedPitch(
+ min_size=(4, 3),
+ max_size=(8, 6),
+ randomizer=randomizer,
+ keep_aspect_ratio=keep_aspect_ratio)
+ pitch.initialize_episode_mjcf(np.random.RandomState(42))
+
+ self.assertBetween(pitch.size[0], 4, 8)
+ self.assertBetween(pitch.size[1], 3, 6)
+
+ if keep_aspect_ratio:
+ self.assertAlmostEqual((pitch.size[0] - 4) / (8. - 4.),
+ (pitch.size[1] - 3) / (6. - 3.))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/soccer/soccer.png b/dm_control/locomotion/soccer/soccer.png
new file mode 100644
index 00000000..e63f8eb0
Binary files /dev/null and b/dm_control/locomotion/soccer/soccer.png differ
diff --git a/dm_control/locomotion/soccer/soccer_ball.py b/dm_control/locomotion/soccer/soccer_ball.py
new file mode 100644
index 00000000..366b1c9c
--- /dev/null
+++ b/dm_control/locomotion/soccer/soccer_ball.py
@@ -0,0 +1,262 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A soccer ball that keeps track of ball-player contacts."""
+
+import os
+
+from dm_control import mjcf
+from dm_control.entities import props
+import numpy as np
+
+from dm_control.utils import io as resources
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'soccer_ball')
+
+# FIFA regulation parameters for a size 5 ball.
+_REGULATION_RADIUS = 0.117 # Meters.
+_REGULATION_MASS = 0.45 # Kilograms.
+
+_DEFAULT_FRICTION = (0.7, 0.05, 0.04) # (slide, spin, roll).
+_DEFAULT_DAMP_RATIO = 0.4
+
+
+def _get_texture(name):
+ contents = resources.GetResource(
+ os.path.join(_ASSETS_PATH, '{}.png'.format(name)))
+ return mjcf.Asset(contents, '.png')
+
+
+def regulation_soccer_ball():
+ return SoccerBall(
+ radius=_REGULATION_RADIUS,
+ mass=_REGULATION_MASS,
+ friction=_DEFAULT_FRICTION,
+ damp_ratio=_DEFAULT_DAMP_RATIO)
+
+
+class SoccerBall(props.Primitive):
+ """A soccer ball that keeps track of entities that come into contact."""
+
+ def _build(self,
+ radius=0.35,
+ mass=0.045,
+ friction=(0.7, 0.075, 0.075),
+ damp_ratio=1.0,
+ name='soccer_ball'):
+ """Builds this soccer ball.
+
+ Args:
+ radius: The radius (in meters) of the ball.
+ mass: Mass (in kilograms) of the ball.
+ friction: Friction parameters of the ball geom with the three dimensions
+ corresponding to (slide, spin, roll) frictions.
+ damp_ratio: A real positive number. Lower implies less damping upon
+ contacts.
+ name: The name of this entity.
+ """
+ super()._build(geom_type='sphere', size=(radius,), name=name)
+ texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='soccer_ball',
+ type='cube',
+ fileup=_get_texture('up'),
+ filedown=_get_texture('down'),
+ filefront=_get_texture('front'),
+ fileback=_get_texture('back'),
+ fileleft=_get_texture('left'),
+ fileright=_get_texture('right'))
+ material = self._mjcf_root.asset.add(
+ 'material', name='soccer_ball', texture=texture)
+
+ if damp_ratio < 0.0:
+ raise ValueError(
+ f'Invalid `damp_ratio` parameter ({damp_ratio} is not positive).')
+
+ self._geom.set_attributes(
+ pos=[0, 0, radius],
+ size=[radius],
+ condim=6,
+ priority=1,
+ mass=mass,
+ friction=friction,
+ solref=[0.02, damp_ratio],
+ material=material)
+
+ # Add some tracking cameras for visualization and logging.
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam_near',
+ pos=[0, -2, 2],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam',
+ pos=[0, -7, 7],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam_far',
+ pos=[0, -10, 10],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+
+ # Keep track of entities to team mapping.
+ self._players = []
+
+ # Initialize tracker attributes.
+ self.initialize_entity_trackers()
+
+ def register_player(self, player):
+ self._players.append(player)
+
+ def initialize_entity_trackers(self):
+ self._last_hit = None
+ self._hit = False
+ self._repossessed = False
+ self._intercepted = False
+
+ # Tracks distance traveled by the ball in between consecutive hits.
+ self._pos_at_last_step = None
+ self._dist_since_last_hit = None
+ self._dist_between_last_hits = None
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._geom_id_to_player = {}
+ for player in self._players:
+ geoms = player.walker.mjcf_model.find_all('geom')
+ for geom in geoms:
+ geom_id = physics.model.name2id(geom.full_identifier, 'geom')
+ self._geom_id_to_player[geom_id] = player
+
+ self.initialize_entity_trackers()
+
+ def after_substep(self, physics, unused_random_state):
+ """Resolve contacts and update ball-player contact trackers."""
+ if self._hit:
+ # Ball has already registered a valid contact within this step (during
+ # one of the previous after_substep calls).
+ return
+
+ # Iterate through all contacts to find the first contact between the ball
+ # and one of the registered entities.
+ for contact in physics.data.contact:
+ # Keep contacts that involve the ball and one of the registered entities.
+ has_self = False
+ for geom_id in (contact.geom1, contact.geom2):
+ if geom_id == self._geom_id:
+ has_self = True
+ else:
+ player = self._geom_id_to_player.get(geom_id)
+
+ if has_self and player:
+ # Detected a contact between the ball and a registered player.
+ if self._last_hit is not None:
+ self._intercepted = player.team != self._last_hit.team
+ else:
+ self._intercepted = True
+
+ # Register repossessed before updating last_hit player.
+ self._repossessed = player is not self._last_hit
+ self._last_hit = player
+ # Register hit event.
+ self._hit = True
+ break
+
+ def before_step(self, physics, random_state):
+ super().before_step(physics, random_state)
+ # Reset per simulation step indicator.
+ self._hit = False
+ self._repossessed = False
+ self._intercepted = False
+
+ def after_step(self, physics, random_state):
+ super().after_step(physics, random_state)
+ pos = physics.bind(self._geom).xpos
+ if self._hit:
+ # SoccerBall is hit on this step. Update dist_between_last_hits
+ # to dist_since_last_hit before resetting dist_since_last_hit.
+ self._dist_between_last_hits = self._dist_since_last_hit
+ self._dist_since_last_hit = 0.
+ self._pos_at_last_step = pos.copy()
+
+ if self._dist_since_last_hit is not None:
+ # Accumulate distance traveled since last hit event.
+ self._dist_since_last_hit += np.linalg.norm(pos - self._pos_at_last_step)
+
+ self._pos_at_last_step = pos.copy()
+
+ @property
+ def last_hit(self):
+ """The player that last came in contact with the ball or `None`."""
+ return self._last_hit
+
+ @property
+ def hit(self):
+ """Indicates if the ball is hit during the last simulation step.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True: if the ball is hit by a registered player during simulation step.
+ False: if not.
+ """
+ return self._hit
+
+ @property
+ def repossessed(self):
+ """Indicates if the ball has been repossessed by a different player.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True: if the ball is hit by a registered player during simulation step
+ and that player is different from `last_hit`.
+ False: if the ball is not hit, or the ball is hit by `last_hit` player.
+ """
+ return self._repossessed
+
+ @property
+ def intercepted(self):
+ """Indicates if the ball has been intercepted by a different team.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True: if the ball is hit for the first time, or repossessed by a player
+ from a different team.
+ False: if the ball is not hit, not repossessed, or repossessed by a
+ teammate of `last_hit`.
+ """
+ return self._intercepted
+
+ @property
+ def dist_between_last_hits(self):
+ """Distance between last consecutive hits.
+
+ Returns:
+ Distance between last two consecutive hit events or `None` if there have
+ not been two consecutive hits on the ball.
+ """
+ return self._dist_between_last_hits
diff --git a/dm_control/locomotion/soccer/soccer_ball_test.py b/dm_control/locomotion/soccer/soccer_ball_test.py
new file mode 100644
index 00000000..ceb4bc79
--- /dev/null
+++ b/dm_control/locomotion/soccer/soccer_ball_test.py
@@ -0,0 +1,73 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.soccer_ball."""
+
+
+from absl.testing import absltest
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.entities import props
+from dm_control.locomotion.soccer import soccer_ball
+from dm_control.locomotion.soccer import team
+import numpy as np
+
+
+class SoccerBallTest(absltest.TestCase):
+
+ def test_detect_hit(self):
+ arena = composer.Arena()
+ ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball')
+ player = team.Player(
+ team=team.Team.HOME,
+ walker=props.Primitive(geom_type='sphere', size=(0.1,), name='home'))
+ arena.add_free_entity(player.walker)
+ ball.register_player(player)
+ arena.add_free_entity(ball)
+
+ random_state = np.random.RandomState(42)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ physics.step()
+
+ ball.initialize_episode(physics, random_state)
+ ball.before_step(physics, random_state)
+ self.assertEqual(ball.hit, False)
+ self.assertEqual(ball.repossessed, False)
+ self.assertEqual(ball.intercepted, False)
+ self.assertIsNone(ball.last_hit)
+ self.assertIsNone(ball.dist_between_last_hits)
+
+ ball.after_substep(physics, random_state)
+ ball.after_step(physics, random_state)
+
+ self.assertEqual(ball.hit, True)
+ self.assertEqual(ball.repossessed, True)
+ self.assertEqual(ball.intercepted, True)
+ self.assertEqual(ball.last_hit, player)
+ # Only one hit registered.
+ self.assertIsNone(ball.dist_between_last_hits)
+
+ def test_has_tracking_cameras(self):
+ ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball')
+ expected_camera_names = ['ball_cam_near', 'ball_cam', 'ball_cam_far']
+ camera_names = [cam.name for cam in ball.mjcf_model.find_all('camera')]
+ self.assertCountEqual(expected_camera_names, camera_names)
+
+ def test_damp_ratio_is_valid(self):
+ with self.assertRaisesRegex(ValueError, 'Invalid `damp_ratio`.*'):
+ soccer_ball.SoccerBall(damp_ratio=-0.5)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/soccer/task.py b/dm_control/locomotion/soccer/task.py
new file mode 100644
index 00000000..9e75b877
--- /dev/null
+++ b/dm_control/locomotion/soccer/task.py
@@ -0,0 +1,267 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A task where players play a soccer game."""
+
+
+from dm_control import composer
+from dm_control.locomotion.soccer import initializers
+from dm_control.locomotion.soccer import observables as observables_lib
+from dm_control.locomotion.soccer import soccer_ball
+from dm_env import specs
+import numpy as np
+
+_THROW_IN_BALL_Z = 0.5
+
+
+def _disable_geom_contacts(entities):
+ for entity in entities:
+ mjcf_model = entity.mjcf_model
+ for geom in mjcf_model.find_all("geom"):
+ geom.set_attributes(contype=0)
+
+
+class Task(composer.Task):
+ """A task where two teams of walkers play soccer."""
+
+ def __init__(
+ self,
+ players,
+ arena,
+ ball=None,
+ initializer=None,
+ observables=None,
+ disable_walker_contacts=False,
+ nconmax_per_player=200,
+ njmax_per_player=400,
+ control_timestep=0.025,
+ tracking_cameras=(),
+ ):
+ """Construct an instance of soccer.Task.
+
+ This task implements the high-level game logic of multi-agent MuJoCo soccer.
+
+ Args:
+ players: a sequence of `soccer.Player` instances, representing
+ participants in the game from both teams.
+ arena: an instance of `soccer.Pitch`, implementing the physical geoms and
+ the sensors associated with the pitch.
+ ball: optional instance of `soccer.SoccerBall`, implementing the physical
+ geoms and sensors associated with the soccer ball. If None, defaults to
+ using `soccer_ball.SoccerBall()`.
+ initializer: optional instance of `soccer.Initializer` that initializes
+ the task at the start of each episode. If None, defaults to
+ `initializers.UniformInitializer()`.
+ observables: optional instance of `soccer.ObservablesAdder` that adds
+ observables for each player. If None, defaults to
+ `observables.CoreObservablesAdder()`.
+ disable_walker_contacts: if `True`, disable physical contacts between
+ players.
+ nconmax_per_player: allocated maximum number of contacts per player. It
+ may be necessary to increase this value if you encounter errors due to
+ `mjWARN_CONTACTFULL`.
+ njmax_per_player: allocated maximum number of scalar constraints per
+ player. It may be necessary to increase this value if you encounter
+ errors due to `mjWARN_CNSTRFULL`.
+ control_timestep: control timestep of the agent.
+ tracking_cameras: a sequence of `camera.MultiplayerTrackingCamera`
+ instances to track the players and ball.
+ """
+ self.arena = arena
+ self.players = players
+
+ self._initializer = initializer or initializers.UniformInitializer()
+ self._observables = observables or observables_lib.CoreObservablesAdder()
+
+ if disable_walker_contacts:
+ _disable_geom_contacts([p.walker for p in self.players])
+
+ # Create ball and attach ball to arena.
+ self.ball = ball or soccer_ball.SoccerBall()
+ self.arena.add_free_entity(self.ball)
+ self.arena.register_ball(self.ball)
+
+ # Register soccer ball contact tracking for players.
+ for player in self.players:
+ player.walker.create_root_joints(self.arena.attach(player.walker))
+ self.ball.register_player(player)
+ # Add per-walker observables.
+ self._observables(self, player)
+
+ self._tracking_cameras = tracking_cameras
+
+ self.set_timesteps(
+ physics_timestep=0.005, control_timestep=control_timestep)
+ self.root_entity.mjcf_model.size.nconmax = nconmax_per_player * len(players)
+ self.root_entity.mjcf_model.size.njmax = njmax_per_player * len(players)
+
+ @property
+ def observables(self):
+ observables = []
+ for player in self.players:
+ observables.append(
+ player.walker.observables.as_dict(fully_qualified=False))
+ return observables
+
+ def _throw_in(self, physics, random_state, ball):
+ x, y, _ = physics.bind(ball.geom).xpos
+ shrink_x, shrink_y = random_state.uniform([0.7, 0.7], [0.9, 0.9])
+ ball.set_pose(physics, [x * shrink_x, y * shrink_y, _THROW_IN_BALL_Z])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+ ball.initialize_entity_trackers()
+
+ def _tracked_entity_positions(self, physics):
+ """Return a list of the positions of the ball and all players."""
+ ball_pos, unused_ball_quat = self.ball.get_pose(physics)
+ entity_positions = [ball_pos]
+ for player in self.players:
+ walker_pos, unused_walker_quat = player.walker.get_pose(physics)
+ entity_positions.append(walker_pos)
+ return entity_positions
+
+ def after_compile(self, physics, random_state):
+ super().after_compile(physics, random_state)
+ for camera in self._tracking_cameras:
+ camera.after_compile(physics)
+
+ def after_step(self, physics, random_state):
+ super().after_step(physics, random_state)
+ for camera in self._tracking_cameras:
+ camera.after_step(self._tracked_entity_positions(physics))
+
+ def initialize_episode_mjcf(self, random_state):
+ self.arena.initialize_episode_mjcf(random_state)
+
+ def initialize_episode(self, physics, random_state):
+ self.arena.initialize_episode(physics, random_state)
+ for player in self.players:
+ player.walker.reinitialize_pose(physics, random_state)
+
+ self._initializer(self, physics, random_state)
+ for camera in self._tracking_cameras:
+ camera.initialize_episode(self._tracked_entity_positions(physics))
+
+ @property
+ def root_entity(self):
+ return self.arena
+
+ def get_reward(self, physics):
+ """Returns a list of per-player rewards.
+
+ Each player will receive a reward of:
+ +1 if their team scored a goal
+ -1 if their team conceded a goal
+ 0 if no goals were scored on this timestep.
+
+ Note: the observations also contain various environment statistics that may
+ be used to derive per-player rewards (as done in
+ http://arxiv.org/abs/1902.07151).
+
+ Args:
+ physics: An instance of `Physics`.
+
+ Returns:
+ A list of 0-dimensional numpy arrays, one per player.
+ """
+ scoring_team = self.arena.detected_goal()
+ if not scoring_team:
+ return [np.zeros((), dtype=np.float32) for _ in self.players]
+
+ rewards = []
+ for p in self.players:
+ if p.team == scoring_team:
+ rewards.append(np.ones((), dtype=np.float32))
+ else:
+ rewards.append(-np.ones((), dtype=np.float32))
+ return rewards
+
+ def get_reward_spec(self):
+ return [
+ specs.Array(name="reward", shape=(), dtype=np.float32)
+ for _ in self.players
+ ]
+
+ def get_discount(self, physics):
+ if self.arena.detected_goal():
+ return np.zeros((), np.float32)
+ return np.ones((), np.float32)
+
+ def get_discount_spec(self):
+ return specs.Array(name="discount", shape=(), dtype=np.float32)
+
+ def should_terminate_episode(self, physics):
+ """Returns True if a goal was scored by either team."""
+ return self.arena.detected_goal() is not None
+
+ def before_step(self, physics, actions, random_state):
+ for player, action in zip(self.players, actions):
+ player.walker.apply_action(physics, action, random_state)
+
+ if self.arena.detected_off_court():
+ self._throw_in(physics, random_state, self.ball)
+
+ def action_spec(self, physics):
+ """Return multi-agent action_spec."""
+ return [player.walker.action_spec for player in self.players]
+
+
+class MultiturnTask(Task):
+ """Continuous game play through scoring events until timeout."""
+
+ def __init__(self,
+ players,
+ arena,
+ ball=None,
+ initializer=None,
+ observables=None,
+ disable_walker_contacts=False,
+ nconmax_per_player=200,
+ njmax_per_player=400,
+ control_timestep=0.025,
+ tracking_cameras=()):
+ """See base class."""
+ super().__init__(
+ players,
+ arena,
+ ball=ball,
+ initializer=initializer,
+ observables=observables,
+ disable_walker_contacts=disable_walker_contacts,
+ nconmax_per_player=nconmax_per_player,
+ njmax_per_player=njmax_per_player,
+ control_timestep=control_timestep,
+ tracking_cameras=tracking_cameras)
+
+ # If `True`, reset ball entity trackers before the next step.
+ self._should_reset = False
+
+ def should_terminate_episode(self, physics):
+ return False
+
+ def get_discount(self, physics):
+ return np.ones((), np.float32)
+
+ def before_step(self, physics, actions, random_state):
+ super().before_step(physics, actions, random_state)
+ if self._should_reset:
+ self.ball.initialize_entity_trackers()
+ self._should_reset = False
+
+ def after_step(self, physics, random_state):
+ super().after_step(physics, random_state)
+ if self.arena.detected_goal():
+ self._initializer(self, physics, random_state)
+ self._should_reset = True
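
Reviewer note on `Task._throw_in` above: the throw-in rule can be checked in isolation. The snippet below is a minimal sketch, not part of the patch. `throw_in_position` is a hypothetical helper that copies the arithmetic from `_throw_in` and takes a plain position instead of a `physics` binding, so it runs without MuJoCo.

```python
import numpy as np

_THROW_IN_BALL_Z = 0.5  # Same constant as in task.py.


def throw_in_position(ball_xpos, random_state):
  """Hypothetical helper mirroring the position update in `Task._throw_in`."""
  x, y, _ = ball_xpos
  # Each coordinate is independently pulled towards the pitch centre by a
  # factor drawn uniformly from [0.7, 0.9).
  shrink_x, shrink_y = random_state.uniform([0.7, 0.7], [0.9, 0.9])
  return np.array([x * shrink_x, y * shrink_y, _THROW_IN_BALL_Z])


random_state = np.random.RandomState(0)
# Assumed example: the ball left the pitch at y = 8 while lying near the ground.
new_pos = throw_in_position([0., 8., 0.1], random_state)
```

Because the shrink factors are strictly below 1, the ball is always moved back towards the origin, and its height is reset to `_THROW_IN_BALL_Z` regardless of where it left the pitch.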
diff --git a/dm_control/locomotion/soccer/task_test.py b/dm_control/locomotion/soccer/task_test.py
new file mode 100644
index 00000000..617a5d11
--- /dev/null
+++ b/dm_control/locomotion/soccer/task_test.py
@@ -0,0 +1,623 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.soccer.task."""
+
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.locomotion import soccer
+from dm_control.locomotion.soccer import camera
+from dm_control.locomotion.soccer import initializers
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+RGBA_BLUE = [.1, .1, .8, 1.]
+RGBA_RED = [.8, .1, .1, 1.]
+
+
+def _walker(name, walker_id, marker_rgba):
+ return soccer.BoxHead(
+ name=name,
+ walker_id=walker_id,
+ marker_rgba=marker_rgba,
+ )
+
+
+def _team_players(team_size, team, team_name, team_color):
+ team_of_players = []
+ for i in range(team_size):
+ team_of_players.append(
+ soccer.Player(team, _walker("%s%d" % (team_name, i), i, team_color)))
+ return team_of_players
+
+
+def _home_team(team_size):
+ return _team_players(team_size, soccer.Team.HOME, "home", RGBA_BLUE)
+
+
+def _away_team(team_size):
+ return _team_players(team_size, soccer.Team.AWAY, "away", RGBA_RED)
+
+
+def _env(players, disable_walker_contacts=True, observables=None,
+ random_state=42, **task_kwargs):
+ return composer.Environment(
+ task=soccer.Task(
+ players=players,
+ arena=soccer.Pitch((20, 15)),
+ observables=observables,
+ disable_walker_contacts=disable_walker_contacts,
+ **task_kwargs
+ ),
+ random_state=random_state,
+ time_limit=1)
+
+
+def _observables_adder(observables_adder):
+ if observables_adder == "core":
+ return soccer.CoreObservablesAdder()
+ if observables_adder == "core_interception":
+ return soccer.MultiObservablesAdder(
+ [soccer.CoreObservablesAdder(),
+ soccer.InterceptionObservablesAdder()])
+ raise ValueError("Unrecognized observables_adder %s" % observables_adder)
+
+
+class TaskTest(parameterized.TestCase):
+
+ def _assert_all_count_equal(self, list_of_lists):
+ """Check that all lists contain the same elements, ignoring order."""
+ if not list_of_lists:
+ return
+
+ first = sorted(list_of_lists[0])
+ for other in list_of_lists[1:]:
+ self.assertCountEqual(first, other)
+
+ @parameterized.named_parameters(
+ ("1vs1_core", 1, "core", 33, True),
+ ("2vs2_core", 2, "core", 43, True),
+ ("1vs1_interception", 1, "core_interception", 41, True),
+ ("2vs2_interception", 2, "core_interception", 51, True),
+ ("1vs1_core_contact", 1, "core", 33, False),
+ ("2vs2_core_contact", 2, "core", 43, False),
+ ("1vs1_interception_contact", 1, "core_interception", 41, False),
+ ("2vs2_interception_contact", 2, "core_interception", 51, False),
+ )
+ def test_step_environment(self, team_size, observables_adder, num_obs,
+ disable_walker_contacts):
+ env = _env(
+ _home_team(team_size) + _away_team(team_size),
+ observables=_observables_adder(observables_adder),
+ disable_walker_contacts=disable_walker_contacts)
+ self.assertLen(env.action_spec(), 2 * team_size)
+ self.assertLen(env.observation_spec(), 2 * team_size)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ for observation, spec in zip(timestep.observation, env.observation_spec()):
+ self.assertLen(spec, num_obs)
+ self.assertCountEqual(list(observation.keys()), list(spec.keys()))
+ for key in observation.keys():
+ self.assertEqual(observation[key].shape, spec[key].shape)
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ # TODO(b/124848293): consolidate environment stepping loop for task tests.
+ @parameterized.named_parameters(
+ ("1vs2", 1, 2, 38),
+ ("2vs1", 2, 1, 38),
+ ("3vs0", 3, 0, 38),
+ ("0vs2", 0, 2, 33),
+ ("2vs2", 2, 2, 43),
+ ("0vs0", 0, 0, None),
+ )
+ def test_num_players(self, home_size, away_size, num_observations):
+ env = _env(_home_team(home_size) + _away_team(away_size))
+ self.assertLen(env.action_spec(), home_size + away_size)
+ self.assertLen(env.observation_spec(), home_size + away_size)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ # Members of the same team should have identical specs.
+ self._assert_all_count_equal(
+ [spec.keys() for spec in env.observation_spec()[:home_size]])
+ self._assert_all_count_equal(
+ [spec.keys() for spec in env.observation_spec()[home_size:]])
+
+ for observation, spec in zip(timestep.observation, env.observation_spec()):
+ self.assertCountEqual(list(observation.keys()), list(spec.keys()))
+ for key in observation.keys():
+ self.assertEqual(observation[key].shape, spec[key].shape)
+
+ self.assertLen(spec, num_observations)
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ self.assertLen(timestep.observation, home_size + away_size)
+
+ self.assertLen(timestep.reward, home_size + away_size)
+ for player_spec, player_reward in zip(env.reward_spec(), timestep.reward):
+ player_spec.validate(player_reward)
+
+ discount_spec = env.discount_spec()
+ discount_spec.validate(timestep.discount)
+
+ def test_all_contacts(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _all_contact_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 5.])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 0., 0., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 3.], quat)
+ walkers[0].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 0., 0., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 1.], quat)
+ walkers[1].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ env.add_extra_hook("initialize_episode", _all_contact_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ def test_symmetric_observations(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _symmetric_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 5., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = -5., -3., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ env.add_extra_hook("initialize_episode", _symmetric_configuration)
+
+ timestep = env.reset()
+ obs_a, obs_b = timestep.observation
+ self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys()))
+ for k in sorted(obs_a.keys()):
+ o_a, o_b = obs_a[k], obs_b[k]
+ self.assertTrue(
+ np.allclose(o_a, o_b) or np.allclose(o_a, -o_b),
+ k + " not equal:" + str(o_a) + ";" + str(o_b))
+
+ def test_symmetric_dynamic_observations(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _symmetric_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball shooting up. Walkers going tangent.
+ ball.set_velocity(physics, velocity=[0., 0., 1.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 5., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[y, -x, 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = -5., -3., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[y, -x, 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _symmetric_configuration)
+
+ timestep = env.reset()
+ obs_a, obs_b = timestep.observation
+ self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys()))
+ for k in sorted(obs_a.keys()):
+ o_a, o_b = obs_a[k], obs_b[k]
+ self.assertTrue(
+ np.allclose(o_a, o_b) or np.allclose(o_a, -o_b),
+ k + " not equal:" + str(o_a) + ";" + str(o_b))
+
+ def test_prev_actions(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ actions = []
+ for i, player in enumerate(env.task.players):
+ spec = player.walker.action_spec
+ actions.append((i + 1) * np.ones(spec.shape, dtype=spec.dtype))
+
+ env.reset()
+ timestep = env.step(actions)
+
+ for walker_idx, obs in enumerate(timestep.observation):
+ np.testing.assert_allclose(
+ np.squeeze(obs["prev_action"], axis=0),
+ actions[walker_idx],
+ err_msg="Walker {}: incorrect previous action.".format(walker_idx))
+
+ @parameterized.named_parameters(
+ dict(testcase_name="1vs2_draw",
+ home_size=1, away_size=2, ball_vel_x=0, expected_home_score=0),
+ dict(testcase_name="1vs2_home_score",
+ home_size=1, away_size=2, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="2vs1_away_score",
+ home_size=2, away_size=1, ball_vel_x=-50, expected_home_score=-1),
+ dict(testcase_name="3vs0_home_score",
+ home_size=3, away_size=0, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="0vs2_home_score",
+ home_size=0, away_size=2, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="2vs2_away_score",
+ home_size=2, away_size=2, ball_vel_x=-50, expected_home_score=-1),
+ )
+ def test_scoring_rewards(
+ self, home_size, away_size, ball_vel_x, expected_home_score):
+ env = _env(_home_team(home_size) + _away_team(away_size))
+
+ def _score_configuration(physics, random_state):
+ del random_state # Unused.
+ # Send the ball shooting towards either the home or away goal.
+ env.task.ball.set_pose(physics, [0., 0., 0.5])
+ env.task.ball.set_velocity(physics,
+ velocity=[ball_vel_x, 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _score_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ # Disable contacts and gravity so that the ball follows a straight path.
+ with env.physics.model.disable("contact", "gravity"):
+
+ timestep = env.reset()
+ with self.subTest("Reward and discount are None on the first timestep"):
+ self.assertTrue(timestep.first())
+ self.assertIsNone(timestep.reward)
+ self.assertIsNone(timestep.discount)
+
+ # Step until the episode ends.
+ timestep = env.step(actions)
+ while not timestep.last():
+ self.assertTrue(timestep.mid())
+ # For non-terminal timesteps, the reward should always be 0 and the
+ # discount should always be 1.
+ np.testing.assert_array_equal(np.hstack(timestep.reward), 0.)
+ self.assertEqual(timestep.discount, 1.)
+ timestep = env.step(actions)
+
+ # If a goal was scored then the episode should have ended with a discount of
+ # 0. If neither team scored and the episode ended due to hitting the time
+ # limit then the discount should be 1.
+ with self.subTest("Correct terminal discount"):
+ if expected_home_score != 0:
+ expected_discount = 0.
+ else:
+ expected_discount = 1.
+ self.assertEqual(timestep.discount, expected_discount)
+
+ with self.subTest("Correct terminal reward"):
+ reward = np.hstack(timestep.reward)
+ np.testing.assert_array_equal(reward[:home_size], expected_home_score)
+ np.testing.assert_array_equal(reward[home_size:], -expected_home_score)
+
+ def test_throw_in(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _throw_in_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 3., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball shooting out of bounds.
+ ball.set_velocity(physics, velocity=[0., 50., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 0., -3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+ x, y, rotation = 0., -5., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _throw_in_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ terminal_ball_vel = np.linalg.norm(
+ timestep.observation[0]["ball_ego_linear_velocity"])
+ self.assertAlmostEqual(terminal_ball_vel, 0.)
+
+ @parameterized.named_parameters(("score", 50., 0.), ("timeout", 0., 1.))
+ def test_terminal_discount(self, init_ball_vel_x, expected_terminal_discount):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _initial_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball moving along the x-axis at `init_ball_vel_x`. Walkers stationary.
+ ball.set_velocity(physics, velocity=[init_ball_vel_x, 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 0., -3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+ x, y, rotation = 0., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _initial_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ self.assertEqual(timestep.discount, expected_terminal_discount)
+
+ @parameterized.named_parameters(("reset_only", False), ("step", True))
+ def test_render(self, take_step):
+ height = 100
+ width = 150
+ tracking_cameras = []
+ for min_distance in [1, 1, 2]:
+ tracking_cameras.append(
+ camera.MultiplayerTrackingCamera(
+ min_distance=min_distance,
+ distance_factor=1,
+ smoothing_update_speed=0.1,
+ width=width,
+ height=height,
+ ))
+ env = _env(_home_team(1) + _away_team(1), tracking_cameras=tracking_cameras)
+ env.reset()
+ if take_step:
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+ env.step(actions)
+ rendered_frames = [cam.render() for cam in tracking_cameras]
+ for frame in rendered_frames:
+ self.assertEqual(frame.shape, (height, width, 3))
+ self.assertTrue(np.array_equal(rendered_frames[0], rendered_frames[1]))
+ self.assertFalse(np.array_equal(rendered_frames[1], rendered_frames[2]))
+
+
+class UniformInitializerTest(parameterized.TestCase):
+
+ @parameterized.parameters([0.3, 0.7])
+ def test_walker_position(self, spawn_ratio):
+ initializer = initializers.UniformInitializer(spawn_ratio=spawn_ratio)
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ root_bodies = [p.walker.root_body for p in env.task.players]
+ xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio
+ env.reset()
+ xy = env.physics.bind(root_bodies).xpos[:, :2].copy()
+ with self.subTest("X and Y positions within bounds"):
+ if np.any(abs(xy) > xy_bounds):
+ self.fail("Walker(s) spawned out of bounds. Expected abs(xy) "
+ "<= {}, got:\n{}".format(xy_bounds, xy))
+ env.reset()
+ xy2 = env.physics.bind(root_bodies).xpos[:, :2].copy()
+ with self.subTest("X and Y positions change after reset"):
+ if np.any(xy == xy2):
+ self.fail("Walker(s) have the same X and/or Y coordinates before and "
+ "after reset. Before: {}, after: {}.".format(xy, xy2))
+
+ def test_walker_rotation(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+
+ def quats_to_eulers(quats):
+ eulers = np.empty((len(quats), 3), dtype=np.double)
+ dt = 1.
+ for i, quat in enumerate(quats):
+ mjbindings.mjlib.mju_quat2Vel(eulers[i], quat, dt)
+ return eulers
+
+ # TODO(b/132671988): Switch to using `get_pose` to get the quaternion once
+ # `BoxHead.get_pose` and `BoxHead.set_pose` are
+ # implemented in a consistent way.
+ def get_quat(walker):
+ return env.physics.bind(walker.root_body).xquat
+
+ env.reset()
+ quats = [get_quat(p.walker) for p in env.task.players]
+ eulers = quats_to_eulers(quats)
+ with self.subTest("Rotation is about the Z-axis only"):
+ np.testing.assert_array_equal(eulers[:, :2], 0.)
+
+ env.reset()
+ quats2 = [get_quat(p.walker) for p in env.task.players]
+ eulers2 = quats_to_eulers(quats2)
+ with self.subTest("Rotation about Z changes after reset"):
+ if np.any(eulers[:, 2] == eulers2[:, 2]):
+ self.fail("Walker(s) have the same rotation about Z before and "
+ "after reset. Before: {}, after: {}."
+ .format(eulers[:, 2], eulers2[:, 2]))
+
+ # TODO(b/132759890): Remove `expectedFailure` decorator once `set_velocity`
+ # works correctly for the `BoxHead` walker.
+ @unittest.expectedFailure
+ def test_walker_velocity(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ root_joints = []
+ non_root_joints = []
+ for player in env.task.players:
+ attachment_frame = mjcf.get_attachment_frame(player.walker.mjcf_model)
+ root_joints.extend(
+ attachment_frame.find_all("joint", immediate_children_only=True))
+ non_root_joints.extend(player.walker.mjcf_model.find_all("joint"))
+ # Assign a non-zero sentinel value to the velocities of all root and
+ # non-root joints.
+ sentinel_velocity = 3.14
+ env.physics.bind(root_joints + non_root_joints).qvel = sentinel_velocity
+ # The initializer should zero the velocities of the root joints, but not the
+ # non-root joints.
+ initializer(env.task, env.physics, env.random_state)
+ np.testing.assert_array_equal(env.physics.bind(non_root_joints).qvel,
+ sentinel_velocity)
+ np.testing.assert_array_equal(env.physics.bind(root_joints).qvel, 0.)
+
+ @parameterized.parameters([
+ dict(spawn_ratio=0.3, init_ball_z=0.4),
+ dict(spawn_ratio=0.5, init_ball_z=0.6),
+ ])
+ def test_ball_position(self, spawn_ratio, init_ball_z):
+ initializer = initializers.UniformInitializer(
+ spawn_ratio=spawn_ratio, init_ball_z=init_ball_z)
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio
+ env.reset()
+ position, _ = env.task.ball.get_pose(env.physics)
+ xyz = position.copy()
+ with self.subTest("X and Y positions within bounds"):
+ if np.any(abs(xyz[:2]) > xy_bounds):
+ self.fail("Ball spawned out of bounds. Expected abs(xy) "
+ "<= {}, got:\n{}".format(xy_bounds, xyz[:2]))
+ with self.subTest("Z position equal to `init_ball_z`"):
+ self.assertEqual(xyz[2], init_ball_z)
+ env.reset()
+ position, _ = env.task.ball.get_pose(env.physics)
+ xyz2 = position.copy()
+ with self.subTest("X and Y positions change after reset"):
+ if np.any(xyz[:2] == xyz2[:2]):
+ self.fail("Ball has the same XY position before and after reset. "
+ "Before: {}, after: {}.".format(xyz[:2], xyz2[:2]))
+
+ def test_ball_velocity(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(1) + _away_team(1), initializer=initializer)
+ ball_root_joint = mjcf.get_frame_freejoint(env.task.ball.mjcf_model)
+ # Set the velocities of the ball root joint to a non-zero sentinel value.
+ env.physics.bind(ball_root_joint).qvel = 3.14
+ initializer(env.task, env.physics, env.random_state)
+ # The initializer should set the ball velocity to zero.
+ ball_velocity = env.physics.bind(ball_root_joint).qvel
+ np.testing.assert_array_equal(ball_velocity, 0.)
+
+
+class _ScoringInitializer(soccer.Initializer):
+ """Initialize the ball for home team to repeatedly score goals."""
+
+ def __init__(self):
+ self._num_calls = 0
+
+ @property
+ def num_calls(self):
+ return self._num_calls
+
+ def __call__(self, task, physics, random_state):
+ # Initialize `ball` on the x-axis with a positive x-velocity.
+ task.ball.set_pose(physics, [2.0, 0.0, 1.5])
+ task.ball.set_velocity(
+ physics, velocity=[100.0, 0.0, 0.0], angular_velocity=0.)
+ for i, player in enumerate(task.players):
+ player.walker.reinitialize_pose(physics, random_state)
+ (_, _, z), quat = player.walker.get_pose(physics)
+ player.walker.set_pose(physics, [-i * 5, 0.0, z], quat)
+ player.walker.set_velocity(physics, velocity=0., angular_velocity=0.)
+
+ self._num_calls += 1
+
+
+class MultiturnTaskTest(parameterized.TestCase):
+
+ def test_multiple_goals(self):
+ initializer = _ScoringInitializer()
+ time_limit = 1.0
+ control_timestep = 0.025
+ env = composer.Environment(
+ task=soccer.MultiturnTask(
+ players=_home_team(1) + _away_team(1),
+ arena=soccer.Pitch((20, 15), field_box=True), # disable throw-in.
+ initializer=initializer,
+ control_timestep=control_timestep),
+ time_limit=time_limit)
+
+ timestep = env.reset()
+ num_steps = 0
+ rewards = [np.zeros(s.shape, s.dtype) for s in env.reward_spec()]
+ while not timestep.last():
+ timestep = env.step([spec.generate_value() for spec in env.action_spec()])
+ for reward, r_t in zip(rewards, timestep.reward):
+ reward += r_t
+ num_steps += 1
+ self.assertEqual(num_steps, time_limit / control_timestep)
+
+ num_scores = initializer.num_calls - 1 # discard initialization.
+ self.assertEqual(num_scores, 6)
+ self.assertEqual(rewards, [
+ np.full((), num_scores, np.float32),
+ np.full((), -num_scores, np.float32)
+ ])
+
+
+if __name__ == "__main__":
+ absltest.main()
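
Reviewer note on the test configurations above: several tests build walker orientations as `[np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]`, which is a unit quaternion in (w, x, y, z) order for a rotation about the Z-axis. A minimal sketch of that convention follows. `yaw_to_quat` and `quat_to_yaw` are hypothetical helper names, not functions from `dm_control`.

```python
import numpy as np


def yaw_to_quat(yaw):
  """Unit quaternion (w, x, y, z) for a rotation of `yaw` radians about Z."""
  return np.array([np.cos(yaw / 2), 0., 0., np.sin(yaw / 2)])


def quat_to_yaw(quat):
  """Recovers the Z rotation from a quaternion with zero x and y components."""
  w, _, _, z = quat
  return 2 * np.arctan2(z, w)


quat = yaw_to_quat(np.pi / 3)
recovered = quat_to_yaw(quat)
```

The round trip is only valid for pure Z rotations, which is the case `test_walker_rotation` checks when it asserts that the X and Y components are zero.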
diff --git a/dm_control/locomotion/soccer/team.py b/dm_control/locomotion/soccer/team.py
new file mode 100644
index 00000000..2cae0ecd
--- /dev/null
+++ b/dm_control/locomotion/soccer/team.py
@@ -0,0 +1,31 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Define teams and players participating in a match."""
+
+import collections
+import enum
+
+
+class Team(enum.Enum):
+ HOME = 0
+ AWAY = 1
+
+
+RGBA_BLUE = [.1, .1, .8, 1.]
+RGBA_RED = [.8, .1, .1, 1.]
+
+
+Player = collections.namedtuple('Player', ['team', 'walker'])
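
Reviewer note on `Team` and `Player`: the sketch below shows how these two definitions feed the per-player reward rule described in the `Task.get_reward` docstring. It is illustrative only. `per_player_rewards` is a hypothetical standalone function, the walkers are replaced by `None` placeholders, and `Team` and `Player` are redefined locally so the snippet runs without `dm_control`.

```python
import collections
import enum

import numpy as np


class Team(enum.Enum):
  HOME = 0
  AWAY = 1


Player = collections.namedtuple('Player', ['team', 'walker'])


def per_player_rewards(players, scoring_team):
  """Hypothetical standalone version of the rule in `Task.get_reward`."""
  if scoring_team is None:
    # No goal on this timestep: every player receives 0.
    return [np.zeros((), dtype=np.float32) for _ in players]
  # +1 for members of the scoring team, -1 for everyone else.
  return [
      np.ones((), dtype=np.float32) if p.team == scoring_team
      else -np.ones((), dtype=np.float32)
      for p in players
  ]


# Assumed 1-vs-2 line-up; `None` stands in for a real walker.
players = [
    Player(Team.HOME, None),
    Player(Team.AWAY, None),
    Player(Team.AWAY, None),
]
rewards = per_player_rewards(players, Team.AWAY)
no_goal = per_player_rewards(players, None)
```

Note that members of a plain `enum.Enum` are always truthy, so `Team.HOME` evaluates as true even though its value is 0. This is why a check such as `if not scoring_team:` can distinguish "no goal" (`None`) from a goal by the home team.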
diff --git a/dm_control/locomotion/tasks/__init__.py b/dm_control/locomotion/tasks/__init__.py
new file mode 100644
index 00000000..9605995f
--- /dev/null
+++ b/dm_control/locomotion/tasks/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tasks in the Locomotion library."""
+
+
+from dm_control.locomotion.tasks.corridors import RunThroughCorridor
+from dm_control.locomotion.tasks.escape import Escape
+# Import1 removed.
+# Import2 removed.
+from dm_control.locomotion.tasks.go_to_target import GoToTarget
+from dm_control.locomotion.tasks.random_goal_maze import ManyGoalsMaze
+from dm_control.locomotion.tasks.random_goal_maze import ManyHeterogeneousGoalsMaze
+from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMaze
+from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMazeAugmentedWithTargets
+from dm_control.locomotion.tasks.reach import TwoTouch
diff --git a/dm_control/locomotion/tasks/corridors.py b/dm_control/locomotion/tasks/corridors.py
new file mode 100644
index 00000000..16eb4762
--- /dev/null
+++ b/dm_control/locomotion/tasks/corridors.py
@@ -0,0 +1,158 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Corridor-based locomotion tasks."""
+
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.utils import rewards
+import numpy as np
+
+
+class RunThroughCorridor(composer.Task):
+ """A task that requires a walker to run through a corridor.
+
+ This task rewards an agent for controlling a walker to move at a specific
+ target velocity along the corridor. The reward depends only on the walker's
+ velocity along the x-axis; control signal magnitude is not penalised.
+ """
+
+ def __init__(self,
+ walker,
+ arena,
+ walker_spawn_position=(0, 0, 0),
+ walker_spawn_rotation=None,
+ target_velocity=3.0,
+ contact_termination=True,
+ terminate_at_height=-0.5,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas.corridors.Corridor`.
+ walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ target_velocity: a number specifying the target velocity (in meters per
+ second) for the walker.
+ contact_termination: whether to terminate if a non-foot geom touches the
+ ground.
+ terminate_at_height: a number specifying the height of end effectors below
+ which the episode terminates.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs.
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ self._walker_spawn_position = walker_spawn_position
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ enabled_observables.append(self._walker.observables.egocentric_camera)
+ for observable in enabled_observables:
+ observable.enabled = True
+
+ self._vel = target_velocity
+ self._contact_termination = contact_termination
+ self._terminate_at_height = terminate_at_height
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state)
+ self._arena.mjcf_model.visual.map.znear = 0.00025
+ self._arena.mjcf_model.visual.map.zfar = 4.
+
+ def initialize_episode(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+ if self._walker_spawn_rotation:
+ rotation = variation.evaluate(
+ self._walker_spawn_rotation, random_state=random_state)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ else:
+ quat = None
+ self._walker.shift_pose(
+ physics,
+ position=variation.evaluate(
+ self._walker_spawn_position, random_state=random_state),
+ quaternion=quat,
+ rotate_velocity=True)
+
+ self._failure_termination = False
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._arena.ground_geoms).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def before_step(self, physics, action, random_state):
+ self._walker.apply_action(physics, action, random_state)
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ if self._contact_termination:
+ for c in physics.data.contact:
+ if self._is_disallowed_contact(c):
+ self._failure_termination = True
+ break
+ if self._terminate_at_height is not None:
+ if any(physics.bind(self._walker.end_effectors).xpos[:, -1] <
+ self._terminate_at_height):
+ self._failure_termination = True
+
+ def get_reward(self, physics):
+ walker_xvel = physics.bind(self._walker.root_body).subtree_linvel[0]
+ xvel_term = rewards.tolerance(
+ walker_xvel, (self._vel, self._vel),
+ margin=self._vel,
+ sigmoid='linear',
+ value_at_margin=0.0)
+ return xvel_term
+
+ def should_terminate_episode(self, physics):
+ return self._failure_termination
+
+ def get_discount(self, physics):
+ if self._failure_termination:
+ return 0.
+ else:
+ return 1.
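Note on `RunThroughCorridor`: the spawn rotation and the reward shaping above can be checked in isolation. The following is a minimal NumPy-only sketch, not the library code. `yaw_to_quaternion` mirrors the quaternion built in `initialize_episode`; `linear_velocity_reward` is a hypothetical stand-in for the `rewards.tolerance` call in `get_reward`, assuming the `'linear'` sigmoid with `value_at_margin=0.0` falls off linearly from 1 at the target velocity to 0 at a distance of one margin.

```python
import numpy as np


def yaw_to_quaternion(yaw):
  """Returns the (w, x, y, z) quaternion for a rotation of `yaw` about z."""
  return np.array([np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)])


def linear_velocity_reward(xvel, target_velocity):
  """Hypothetical stand-in for the tolerance-based velocity reward.

  Assumes bounds=(target, target), margin=target and a linear falloff that
  reaches zero at the margin, as configured in `get_reward`.
  """
  distance = abs(xvel - target_velocity) / target_velocity
  return max(0.0, 1.0 - distance)


if __name__ == '__main__':
  print(yaw_to_quaternion(np.pi))          # 180 degree yaw.
  print(linear_velocity_reward(1.5, 3.0))  # Half the target velocity.
```

Under these assumptions the reward is symmetric around the target velocity: moving at twice the target speed scores the same as standing still.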
diff --git a/dm_control/locomotion/tasks/corridors_test.py b/dm_control/locomotion/tasks/corridors_test.py
new file mode 100644
index 00000000..5266fc58
--- /dev/null
+++ b/dm_control/locomotion/tasks/corridors_test.py
@@ -0,0 +1,134 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.tasks.corridors."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.variation import deterministic
+from dm_control.composer.variation import rotations
+from dm_control.locomotion.arenas import corridors as corridor_arenas
+from dm_control.locomotion.tasks import corridors as corridor_tasks
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+
+
+class CorridorsTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ dict(position_offset=(0, 0, 0),
+ rotate_180_degrees=False,
+ use_variations=False),
+ dict(position_offset=(1, 2, 3),
+ rotate_180_degrees=True,
+ use_variations=True))
+ def test_walker_is_correctly_reinitialized(
+ self, position_offset, rotate_180_degrees, use_variations):
+ walker_spawn_position = position_offset
+
+ if not rotate_180_degrees:
+ walker_spawn_rotation = None
+ else:
+ walker_spawn_rotation = np.pi
+
+ if use_variations:
+ walker_spawn_position = deterministic.Constant(position_offset)
+ walker_spawn_rotation = deterministic.Constant(walker_spawn_rotation)
+
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = corridor_arenas.EmptyCorridor()
+ task = corridor_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=walker_spawn_position,
+ walker_spawn_rotation=walker_spawn_rotation)
+
+ # Randomize the initial pose and joint positions in order to check that they
+ # are set correctly by `initialize_episode`.
+ random_state = np.random.RandomState(12345)
+ task.initialize_episode_mjcf(random_state)
+ physics = mjcf.Physics.from_mjcf_model(task.root_entity.mjcf_model)
+
+ walker_joints = walker.mjcf_model.find_all('joint')
+ physics.bind(walker_joints).qpos = random_state.uniform(
+ size=len(walker_joints))
+ walker.set_pose(physics,
+ position=random_state.uniform(size=3),
+ quaternion=rotations.UniformQuaternion()(random_state))
+
+ task.initialize_episode(physics, random_state)
+ physics.forward()
+
+ with self.subTest('Correct joint positions'):
+ walker_qpos = physics.bind(walker_joints).qpos
+ if walker.upright_pose.qpos is not None:
+ np.testing.assert_array_equal(walker_qpos, walker.upright_pose.qpos)
+ else:
+ walker_qpos0 = physics.bind(walker_joints).qpos0
+ np.testing.assert_array_equal(walker_qpos, walker_qpos0)
+
+ walker_xpos, walker_xquat = walker.get_pose(physics)
+
+ with self.subTest('Correct position'):
+ expected_xpos = walker.upright_pose.xpos + np.array(position_offset)
+ np.testing.assert_array_equal(walker_xpos, expected_xpos)
+
+ with self.subTest('Correct orientation'):
+ upright_xquat = walker.upright_pose.xquat.copy()
+ upright_xquat /= np.linalg.norm(walker.upright_pose.xquat)
+ if rotate_180_degrees:
+ expected_xquat = (-upright_xquat[3], -upright_xquat[2],
+ upright_xquat[1], upright_xquat[0])
+ else:
+ expected_xquat = upright_xquat
+ np.testing.assert_allclose(walker_xquat, expected_xquat)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = corridor_arenas.EmptyCorridor()
+ task = corridor_tasks.RunThroughCorridor(walker, arena)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ self.assertEqual(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+ self.assertEqual(task.get_discount(env.physics), 0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/escape.py b/dm_control/locomotion/tasks/escape.py
new file mode 100644
index 00000000..cdb9addc
--- /dev/null
+++ b/dm_control/locomotion/tasks/escape.py
@@ -0,0 +1,184 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Escape locomotion tasks."""
+
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable as base_observable
+from dm_control.rl import control
+from dm_control.utils import rewards
+import numpy as np
+
+# Constants related to terrain generation.
+_HEIGHTFIELD_ID = 0
+
+
+class Escape(composer.Task):
+ """A task solved by escaping a starting area (e.g. bowl-shaped terrain)."""
+
+ def __init__(self,
+ walker,
+ arena,
+ walker_spawn_position=(0, 0, 0),
+ walker_spawn_rotation=None,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas`.
+ walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs.
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ self._walker_spawn_position = walker_spawn_position
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ enabled_observables.append(self._walker.observables.egocentric_camera)
+ for observable in enabled_observables:
+ observable.enabled = True
+
+ if 'CMUHumanoid' in str(type(self._walker)):
+ core_body = 'walker/root'
+ self._reward_body = 'walker/root'
+ elif 'Rat' in str(type(self._walker)):
+ core_body = 'walker/torso'
+ self._reward_body = 'walker/head'
+ else:
+ raise ValueError('Expects Rat or CMUHumanoid.')
+
+ def _origin(physics):
+ """Returns origin position in the torso frame."""
+ torso_frame = physics.named.data.xmat[core_body].reshape(3, 3)
+ torso_pos = physics.named.data.xpos[core_body]
+ return -torso_pos.dot(torso_frame)
+
+ self._walker.observables.add_observable(
+ 'origin', base_observable.Generic(_origin))
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def initialize_episode_mjcf(self, random_state):
+ if hasattr(self._arena, 'regenerate'):
+ self._arena.regenerate(random_state)
+ self._arena.mjcf_model.visual.map.znear = 0.00025
+ self._arena.mjcf_model.visual.map.zfar = 50.
+
+ def initialize_episode(self, physics, random_state):
+ super().initialize_episode(physics, random_state)
+
+ # Initial configuration.
+ orientation = random_state.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, self._walker, orientation)
+
+ def get_reward(self, physics):
+ # Escape reward term.
+ terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ escape_reward = rewards.tolerance(
+ np.asarray(np.linalg.norm(
+ physics.named.data.site_xpos[self._reward_body])),
+ bounds=(terrain_size, float('inf')),
+ margin=terrain_size,
+ value_at_margin=0,
+ sigmoid='linear')
+ upright_reward = _upright_reward(physics, self._walker, deviation_angle=30)
+ return upright_reward * escape_reward
+
+ def get_discount(self, physics):
+ return 1.
+
+
+def _find_non_contacting_height(physics, walker, orientation,
+ x_pos=0.0, y_pos=0.0, maxiter=1000):
+ """Find a height with no contacts given a body orientation.
+
+ Args:
+ physics: An instance of `Physics`.
+ walker: the focal walker.
+ orientation: A quaternion.
+ x_pos: A float. Position along global x-axis.
+ y_pos: A float. Position along global y-axis.
+ maxiter: An int. Maximum number of iterations to try.
+ """
+ z_pos = 0.0 # Start embedded in the floor.
+ num_contacts = 1
+ count = 1
+ # Move up in 1cm increments until no contacts.
+ while num_contacts > 0:
+ try:
+ with physics.reset_context():
+ freejoint = mjcf.get_frame_freejoint(walker.mjcf_model)
+ physics.bind(freejoint).qpos[:3] = x_pos, y_pos, z_pos
+ physics.bind(freejoint).qpos[3:] = orientation
+ except control.PhysicsError:
+ # We may encounter a PhysicsError here due to filling the contact
+ # buffer, in which case we simply increment the height and continue.
+ pass
+ num_contacts = physics.data.ncon
+ z_pos += 0.01
+ count += 1
+ if count > maxiter:
+ raise ValueError(
+ 'maxiter reached: possibly contacts in null pose of body.'
+ )
+
+
+def _upright_reward(physics, walker, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+
+ Args:
+ physics: an instance of `Physics`.
+ walker: the focal walker.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ deviation = np.cos(np.deg2rad(deviation_angle))
+ upright_torso = physics.bind(walker.root_body).xmat[-1]
+ if hasattr(walker, 'pelvis_body'):
+ upright_pelvis = physics.bind(walker.pelvis_body).xmat[-1]
+ upright_zz = np.stack([upright_torso, upright_pelvis])
+ else:
+ upright_zz = upright_torso
+ upright = rewards.tolerance(upright_zz,
+ bounds=(deviation, float('inf')),
+ sigmoid='linear',
+ margin=1 + deviation,
+ value_at_margin=0)
+ return np.min(upright)
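Note on `_upright_reward`: the shaping can be written in closed form. The sketch below is a NumPy-only approximation, not the library code. It assumes that `rewards.tolerance` with `sigmoid='linear'` and `value_at_margin=0` decreases linearly from 1 at the lower bound to 0 one margin below it. The function name `upright_reward` and its `zz` argument (the zz entries of the body rotation matrices) are illustrative.

```python
import numpy as np


def upright_reward(zz, deviation_angle=0.0):
  """Closed-form sketch of the upright reward.

  Args:
    zz: a float or sequence of floats in [-1, 1], the projection of each
      body's z-axis onto the global z-axis.
    deviation_angle: a float, in degrees, within which the reward is 1.

  Returns:
    The minimum reward over all supplied bodies, as a float.
  """
  deviation = np.cos(np.deg2rad(deviation_angle))
  zz = np.asarray(zz, dtype=float)
  # Linear ramp: 0 at zz == -1 (upside-down), 1 at zz >= deviation.
  reward = np.clip((1.0 + zz) / (1.0 + deviation), 0.0, 1.0)
  return float(np.min(reward))


if __name__ == '__main__':
  print(upright_reward(1.0))         # Fully upright.
  print(upright_reward([1.0, 0.0]))  # Torso upright, pelvis horizontal.
```

Taking the minimum means that, when both a torso and a pelvis are supplied, the less upright of the two determines the reward.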
diff --git a/dm_control/locomotion/tasks/escape_test.py b/dm_control/locomotion/tasks/escape_test.py
new file mode 100644
index 00000000..9fb39a8b
--- /dev/null
+++ b/dm_control/locomotion/tasks/escape_test.py
@@ -0,0 +1,85 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.escape."""
+
+
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import bowl
+from dm_control.locomotion.tasks import escape
+from dm_control.locomotion.walkers import rodent
+import numpy as np
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+class EscapeTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = rodent.Rat()
+
+ # Build a bowl-shaped arena.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for escaping the bowl while
+ # remaining upright.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+ def test_contact(self):
+ walker = rodent.Rat()
+
+ # Build a bowl-shaped arena.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for escaping the bowl while
+ # remaining upright.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # The Escape task defines no failure termination, so stepping the
+ # environment should neither end the episode nor reduce the discount.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/go_to_target.py b/dm_control/locomotion/tasks/go_to_target.py
new file mode 100644
index 00000000..cd30cf96
--- /dev/null
+++ b/dm_control/locomotion/tasks/go_to_target.py
@@ -0,0 +1,217 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Task for a walker to move to a target."""
+
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+import numpy as np
+
+DEFAULT_DISTANCE_TOLERANCE_TO_TARGET = 1.0
+
+
+class GoToTarget(composer.Task):
+ """A task that requires a walker to move towards a target."""
+
+ def __init__(self,
+ walker,
+ arena,
+ moving_target=False,
+ target_relative=False,
+ target_relative_dist=1.5,
+ steps_before_moving_target=10,
+ distance_tolerance=DEFAULT_DISTANCE_TOLERANCE_TO_TARGET,
+ target_spawn_position=None,
+ walker_spawn_position=None,
+ walker_spawn_rotation=None,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas.floors.Floor`.
+ moving_target: bool, whether the target should move after the walker
+ reaches it.
+ target_relative: bool, whether a new target position is set relative to
+ the walker's current position.
+ target_relative_dist: float, new target distance range if
+ using target_relative.
+ steps_before_moving_target: int, the number of steps before the target
+ moves, if moving_target==True.
+ distance_tolerance: float, the distance to the target position within
+ which the walker receives reward.
+ target_spawn_position: a sequence of 2 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the target is spawned at the beginning of an episode.
+ If None, the entire arena is used to generate random target positions.
+ walker_spawn_position: a sequence of 2 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ If None, the entire arena is used to generate random spawn positions.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs.
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+
+ arena_position = distributions.Uniform(
+ low=-np.array(arena.size) / 2, high=np.array(arena.size) / 2)
+ if target_spawn_position is not None:
+ self._target_spawn_position = target_spawn_position
+ else:
+ self._target_spawn_position = arena_position
+
+ if walker_spawn_position is not None:
+ self._walker_spawn_position = walker_spawn_position
+ else:
+ self._walker_spawn_position = arena_position
+
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ self._distance_tolerance = distance_tolerance
+ self._moving_target = moving_target
+ self._target_relative = target_relative
+ self._target_relative_dist = target_relative_dist
+ self._steps_before_moving_target = steps_before_moving_target
+ self._reward_step_counter = 0
+
+ self._target = self.root_entity.mjcf_model.worldbody.add(
+ 'site',
+ name='target',
+ type='sphere',
+ pos=(0., 0., 0.),
+ size=(0.1,),
+ rgba=(0.9, 0.6, 0.6, 1.0))
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ for obs in enabled_observables:
+ obs.enabled = True
+
+ walker.observables.add_egocentric_vector(
+ 'target',
+ observable.MJCFFeature('pos', self._target),
+ origin_callable=lambda physics: physics.bind(walker.root_body).xpos)
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def target_position(self, physics):
+ return np.array(physics.bind(self._target).pos)
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state=random_state)
+
+ target_x, target_y = variation.evaluate(
+ self._target_spawn_position, random_state=random_state)
+ self._target.pos = [target_x, target_y, 0.]
+
+ def initialize_episode(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+ if self._walker_spawn_rotation:
+ rotation = variation.evaluate(
+ self._walker_spawn_rotation, random_state=random_state)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ else:
+ quat = None
+ walker_x, walker_y = variation.evaluate(
+ self._walker_spawn_position, random_state=random_state)
+ self._walker.shift_pose(
+ physics,
+ position=[walker_x, walker_y, 0.],
+ quaternion=quat,
+ rotate_velocity=True)
+
+ self._failure_termination = False
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._arena.ground_geoms).element_id)
+ self._ground_geomids.add(physics.bind(self._target).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def should_terminate_episode(self, physics):
+ return self._failure_termination
+
+ def get_discount(self, physics):
+ if self._failure_termination:
+ return 0.
+ else:
+ return 1.
+
+ def get_reward(self, physics):
+ reward = 0.
+ distance = np.linalg.norm(
+ physics.bind(self._target).pos[:2] -
+ physics.bind(self._walker.root_body).xpos[:2])
+ if distance < self._distance_tolerance:
+ reward = 1.
+ if self._moving_target:
+ self._reward_step_counter += 1
+ return reward
+
+ def before_step(self, physics, action, random_state):
+ self._walker.apply_action(physics, action, random_state)
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ for contact in physics.data.contact:
+ if self._is_disallowed_contact(contact):
+ self._failure_termination = True
+ break
+ if (self._moving_target and
+ self._reward_step_counter >= self._steps_before_moving_target):
+
+ # Reset the target position.
+ if self._target_relative:
+ walker_pos = physics.bind(self._walker.root_body).xpos[:2]
+ target_x, target_y = random_state.uniform(
+ -np.array([self._target_relative_dist, self._target_relative_dist]),
+ np.array([self._target_relative_dist, self._target_relative_dist]))
+ target_x += walker_pos[0]
+ target_y += walker_pos[1]
+ else:
+ target_x, target_y = variation.evaluate(
+ self._target_spawn_position, random_state=random_state)
+ physics.bind(self._target).pos = [target_x, target_y, 0.]
+
+ # Reset the number of steps at the target for the moving target.
+ self._reward_step_counter = 0
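Note on `GoToTarget`: the sparse reward and the walker-relative target resampling can be exercised without a physics instance. This is a minimal NumPy-only sketch; the function names `sparse_target_reward` and `sample_relative_target` are hypothetical and do not exist in the library.

```python
import numpy as np


def sparse_target_reward(walker_xy, target_xy, distance_tolerance=1.0):
  """Returns 1.0 if the walker is strictly within tolerance, else 0.0."""
  distance = np.linalg.norm(np.asarray(target_xy) - np.asarray(walker_xy))
  return 1.0 if distance < distance_tolerance else 0.0


def sample_relative_target(walker_xy, max_offset, random_state):
  """Samples a target uniformly in a square centred on the walker.

  Mirrors the `target_relative` branch of `after_step`: each coordinate is
  offset by a value drawn uniformly from [-max_offset, max_offset).
  """
  offset = random_state.uniform(-max_offset, max_offset, size=2)
  return np.asarray(walker_xy, dtype=float) + offset


if __name__ == '__main__':
  rng = np.random.RandomState(12345)
  target = sample_relative_target((2.0, 3.0), 1.5, rng)
  print(target, sparse_target_reward((2.0, 3.0), target))
```

Because the offsets are drawn from a square rather than a disc, the new target can be up to `max_offset * sqrt(2)` away from the walker.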
diff --git a/dm_control/locomotion/tasks/go_to_target_test.py b/dm_control/locomotion/tasks/go_to_target_test.py
new file mode 100644
index 00000000..f88052ec
--- /dev/null
+++ b/dm_control/locomotion/tasks/go_to_target_test.py
@@ -0,0 +1,156 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.tasks.go_to_target."""
+
+
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.tasks import go_to_target
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+
+
+class GoToTargetTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/target', timestep.observation)
+
+ def test_target_position_randomized_on_reset(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+ first_target_position = task.target_position(env.physics)
+ env.reset()
+ second_target_position = task.target_position(env.physics)
+ self.assertFalse(np.all(first_target_position == second_target_position),
+ 'Target positions are unexpectedly identical.')
+
+ def test_reward_fixed_target(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ target_position = task.target_position(env.physics)
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+ for _ in range(2):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+ walker_pos = env.physics.bind(walker.root_body).xpos
+ walker.set_pose(
+ env.physics,
+ position=[target_position[0], target_position[1], walker_pos[2]])
+ env.physics.forward()
+
+ # Receive reward while the agent remains at that location.
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 1)
+
+ # Target position should not change.
+ np.testing.assert_array_equal(target_position,
+ task.target_position(env.physics))
+
+ def test_reward_moving_target(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+
+ steps_before_moving_target = 2
+ task = go_to_target.GoToTarget(
+ walker=walker,
+ arena=arena,
+ moving_target=True,
+ steps_before_moving_target=steps_before_moving_target)
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ target_position = task.target_position(env.physics)
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+ for _ in range(2):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+
+ walker_pos = env.physics.bind(walker.root_body).xpos
+ walker.set_pose(
+ env.physics,
+ position=[target_position[0], target_position[1], walker_pos[2]])
+ env.physics.forward()
+
+ # Receive reward while the agent remains at that location.
+ for _ in range(steps_before_moving_target):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 1)
+ np.testing.assert_array_equal(target_position,
+ task.target_position(env.physics))
+
+    # After more than steps_before_moving_target steps at the target, the
+    # target should move and the reward should be 0.
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(walker=walker, arena=arena)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/random_goal_maze.py b/dm_control/locomotion/tasks/random_goal_maze.py
new file mode 100644
index 00000000..bf9a0084
--- /dev/null
+++ b/dm_control/locomotion/tasks/random_goal_maze.py
@@ -0,0 +1,549 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A task consisting of finding goals/targets in a random maze."""
+
+import collections
+import itertools
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable as observable_lib
+from dm_control.locomotion.props import target_sphere
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+_NUM_RAYS = 10
+
+# Aliveness in [-1., 0.].
+DEFAULT_ALIVE_THRESHOLD = -0.5
+
+DEFAULT_PHYSICS_TIMESTEP = 0.001
+DEFAULT_CONTROL_TIMESTEP = 0.025
+
+
+class NullGoalMaze(composer.Task):
+ """A base task for maze with goals."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ enable_global_task_observables=False,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ """Initializes goal-directed maze task.
+
+ Args:
+ walker: The body to navigate the maze.
+ maze_arena: The physical maze arena object.
+ randomize_spawn_position: Flag to randomize position of spawning.
+ randomize_spawn_rotation: Flag to randomize orientation of spawning.
+ rotation_bias_factor: A non-negative number that concentrates initial
+ orientation away from walls. When set to zero, the initial orientation
+ is uniformly random. The larger the value of this number, the more
+ likely it is that the initial orientation would face the direction that
+ is farthest away from a wall.
+ aliveness_reward: Reward for being alive.
+      aliveness_threshold: The episode terminates when the walker's aliveness
+        falls below this threshold.
+      contact_termination: Whether to terminate if a non-foot geom touches the
+        ground.
+      enable_global_task_observables: Flag to provide task observables that
+        contain global information, including map layout.
+      physics_timestep: Timestep of the physics simulation.
+      control_timestep: Timestep at which the agent changes its action.
+ """
+ self._walker = walker
+ self._maze_arena = maze_arena
+ self._walker.create_root_joints(self._maze_arena.attach(self._walker))
+
+ self._randomize_spawn_position = randomize_spawn_position
+ self._randomize_spawn_rotation = randomize_spawn_rotation
+ self._rotation_bias_factor = rotation_bias_factor
+
+ self._aliveness_reward = aliveness_reward
+ self._aliveness_threshold = aliveness_threshold
+ self._contact_termination = contact_termination
+ self._discount = 1.0
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ self._walker.observables.egocentric_camera.height = 64
+ self._walker.observables.egocentric_camera.width = 64
+
+ for observable in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors):
+ observable.enabled = True
+ self._walker.observables.egocentric_camera.enabled = True
+
+ if enable_global_task_observables:
+ # Reveal maze text map as observable.
+ maze_obs = observable_lib.Generic(
+ lambda _: self._maze_arena.maze.entity_layer)
+ maze_obs.enabled = True
+
+      # Absolute walker position.
+ def get_walker_pos(physics):
+ walker_pos = physics.bind(self._walker.root_body).xpos
+ return walker_pos
+ absolute_position = observable_lib.Generic(get_walker_pos)
+ absolute_position.enabled = True
+
+      # Absolute walker orientation, as a 3x3 rotation matrix.
+ def get_walker_ori(physics):
+ walker_ori = np.reshape(
+ physics.bind(self._walker.root_body).xmat, (3, 3))
+ return walker_ori
+ absolute_orientation = observable_lib.Generic(get_walker_ori)
+ absolute_orientation.enabled = True
+
+      # Discrete (i, j) cell occupied by the walker in the maze layout.
+ def get_walker_ij(physics):
+ walker_xypos = physics.bind(self._walker.root_body).xpos[:-1]
+ walker_rel_origin = (
+ (walker_xypos +
+ np.sign(walker_xypos) * self._maze_arena.xy_scale / 2) /
+ (self._maze_arena.xy_scale)).astype(int)
+ x_offset = (self._maze_arena.maze.width - 1) / 2
+ y_offset = (self._maze_arena.maze.height - 1) / 2
+ walker_ij = walker_rel_origin + np.array([x_offset, y_offset])
+ return walker_ij
+ absolute_position_discrete = observable_lib.Generic(get_walker_ij)
+ absolute_position_discrete.enabled = True
+
+ self._task_observables = collections.OrderedDict({
+ 'maze_layout': maze_obs,
+ 'absolute_position': absolute_position,
+ 'absolute_orientation': absolute_orientation,
+ 'location_in_maze': absolute_position_discrete, # from bottom left
+ })
+ else:
+ self._task_observables = collections.OrderedDict({})
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def name(self):
+ return 'goal_maze'
+
+ @property
+ def root_entity(self):
+ return self._maze_arena
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._maze_arena.regenerate()
+
+ def _respawn(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+
+ if self._randomize_spawn_position:
+ self._spawn_position = self._maze_arena.spawn_positions[
+ random_state.randint(0, len(self._maze_arena.spawn_positions))]
+ else:
+ self._spawn_position = self._maze_arena.spawn_positions[0]
+
+ if self._randomize_spawn_rotation:
+ # Move walker up out of the way before raycasting.
+ self._walker.shift_pose(physics, [0.0, 0.0, 100.0])
+
+ distances = []
+ geomid_out = np.array([-1], dtype=np.intc)
+ for i in range(_NUM_RAYS):
+ theta = 2 * np.pi * i / _NUM_RAYS
+ pos = np.array([self._spawn_position[0], self._spawn_position[1], 0.1],
+ dtype=np.float64)
+ vec = np.array([np.cos(theta), np.sin(theta), 0], dtype=np.float64)
+ dist = mjbindings.mjlib.mj_ray(
+ physics.model.ptr, physics.data.ptr, pos, vec,
+ None, 1, -1, geomid_out)
+ distances.append(dist)
+
+ def remap_with_bias(x):
+ """Remaps values [-1, 1] -> [-1, 1] with bias."""
+ return np.tanh((1 + self._rotation_bias_factor) * np.arctanh(x))
+
+ max_theta = 2 * np.pi * np.argmax(distances) / _NUM_RAYS
+ rotation = max_theta + np.pi * (
+ 1 + remap_with_bias(random_state.uniform(-1, 1)))
+
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+
+ # Move walker back down.
+ self._walker.shift_pose(physics, [0.0, 0.0, -100.0])
+ else:
+ quat = None
+
+ self._walker.shift_pose(
+ physics, [self._spawn_position[0], self._spawn_position[1], 0.0],
+ quat,
+ rotate_velocity=True)
+
+ def initialize_episode(self, physics, random_state):
+ super().initialize_episode(physics, random_state)
+ self._respawn(physics, random_state)
+ self._discount = 1.0
+
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._maze_arena.ground_geoms).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ if self._contact_termination:
+ for c in physics.data.contact:
+ if self._is_disallowed_contact(c):
+ self._failure_termination = True
+ break
+
+ def should_terminate_episode(self, physics):
+ if self._walker.aliveness(physics) < self._aliveness_threshold:
+ self._failure_termination = True
+ if self._failure_termination:
+ self._discount = 0.0
+ return True
+ else:
+ return False
+
+ def get_reward(self, physics):
+ del physics
+ return self._aliveness_reward
+
+ def get_discount(self, physics):
+ del physics
+ return self._discount
+
+
+class RepeatSingleGoalMaze(NullGoalMaze):
+ """Requires an agent to repeatedly find the same goal in a maze."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target=None,
+ target_reward_scale=1.0,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ max_repeats=0,
+ enable_global_task_observables=False,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP,
+ regenerate_maze_on_repeat=False):
+ super().__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ enable_global_task_observables=enable_global_task_observables,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ if target is None:
+ target = target_sphere.TargetSphere()
+ self._target = target
+ self._rewarded_this_step = False
+ self._maze_arena.attach(target)
+ self._target_reward_scale = target_reward_scale
+ self._max_repeats = max_repeats
+ self._targets_obtained = 0
+ self._regenerate_maze_on_repeat = regenerate_maze_on_repeat
+
+ if enable_global_task_observables:
+ xpos_origin_callable = lambda phys: phys.bind(walker.root_body).xpos
+
+ def _target_pos(physics, target=target):
+ return physics.bind(target.geom).xpos
+
+ walker.observables.add_egocentric_vector(
+ 'target_0',
+ observable_lib.Generic(_target_pos),
+ origin_callable=xpos_origin_callable)
+
+ def initialize_episode_mjcf(self, random_state):
+ super().initialize_episode_mjcf(random_state)
+ self._target_position = self._maze_arena.target_positions[
+ random_state.randint(0, len(self._maze_arena.target_positions))]
+ mjcf.get_attachment_frame(
+ self._target.mjcf_model).pos = self._target_position
+
+ def initialize_episode(self, physics, random_state):
+ super().initialize_episode(physics, random_state)
+ self._rewarded_this_step = False
+ self._targets_obtained = 0
+
+ def after_step(self, physics, random_state):
+ super().after_step(physics, random_state)
+ if self._target.activated:
+ self._rewarded_this_step = True
+ self._targets_obtained += 1
+ if self._targets_obtained <= self._max_repeats:
+ if self._regenerate_maze_on_repeat:
+ self.initialize_episode_mjcf(random_state)
+ self._target.set_pose(physics, self._target_position)
+ self._respawn(physics, random_state)
+ self._target.reset(physics)
+ else:
+ self._rewarded_this_step = False
+
+ def should_terminate_episode(self, physics):
+ if super().should_terminate_episode(physics):
+ return True
+    # Return an explicit bool instead of implicitly falling through to None.
+    return self._targets_obtained > self._max_repeats
+
+ def get_reward(self, physics):
+ del physics
+ if self._rewarded_this_step:
+ target_reward = self._target_reward_scale
+ else:
+ target_reward = 0.0
+ return target_reward + self._aliveness_reward
+
+
+class ManyHeterogeneousGoalsMaze(NullGoalMaze):
+ """Requires an agent to find multiple goals with different rewards."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target_builders,
+ target_type_rewards,
+ target_type_proportions,
+ shuffle_target_builders=False,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super().__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ self._active_targets = []
+ self._target_builders = target_builders
+ self._target_type_rewards = tuple(target_type_rewards)
+ self._target_type_fractions = (
+ np.array(target_type_proportions, dtype=float) /
+ np.sum(target_type_proportions))
+ self._shuffle_target_builders = shuffle_target_builders
+
+ def _get_targets(self, total_target_count, random_state):
+ # Multiply total target count by the fraction for each type, rounded down.
+ target_numbers = np.array([int(frac * total_target_count)
+ for frac in self._target_type_fractions])
+
+ # Calculate deviations from the ideal ratio incurred by rounding.
+ errors = (self._target_type_fractions -
+ target_numbers / float(total_target_count))
+
+ # Sort the target types by deviations from ideal ratios.
+ target_types_sorted_by_errors = list(np.argsort(errors))
+
+ # Top up individual target classes until we reach the desired total,
+ # starting from the class that is furthest away from the ideal ratio.
+ current_total = np.sum(target_numbers)
+ while current_total < total_target_count:
+ target_numbers[target_types_sorted_by_errors.pop()] += 1
+ current_total += 1
+
+ if self._shuffle_target_builders:
+ random_state.shuffle(self._target_builders)
+
+ all_targets = []
+ for target_type, num in enumerate(target_numbers):
+ targets = []
+ target_builder = self._target_builders[target_type]
+ for i in range(num):
+ target = target_builder(name='target_{}_{}'.format(target_type, i))
+ targets.append(target)
+ all_targets.append(targets)
+ return all_targets
+
+ def initialize_episode_mjcf(self, random_state):
+ super(
+ ManyHeterogeneousGoalsMaze, self).initialize_episode_mjcf(random_state)
+ for target in itertools.chain(*self._active_targets):
+ target.detach()
+ target_positions = list(self._maze_arena.target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._get_targets(len(target_positions), random_state)
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ self._maze_arena.attach(target)
+ mjcf.get_attachment_frame(target.mjcf_model).pos = pos
+ target.initialize_episode_mjcf(random_state)
+ self._active_targets = all_targets
+ self._target_rewarded = [[False] * len(targets) for targets in all_targets]
+
+ def get_reward(self, physics):
+ del physics
+ reward = self._aliveness_reward
+ for target_type, targets in enumerate(self._active_targets):
+ for i, target in enumerate(targets):
+ if target.activated and not self._target_rewarded[target_type][i]:
+ reward += self._target_type_rewards[target_type]
+ self._target_rewarded[target_type][i] = True
+ return reward
+
+ def should_terminate_episode(self, physics):
+ if super(ManyHeterogeneousGoalsMaze,
+ self).should_terminate_episode(physics):
+ return True
+ else:
+ for target in itertools.chain(*self._active_targets):
+ if not target.activated:
+ return False
+ # All targets have been activated: successful termination.
+ return True
+
+
+class ManyGoalsMaze(ManyHeterogeneousGoalsMaze):
+ """Requires an agent to find all goals in a random maze."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target_builder,
+ target_reward_scale=1.0,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super().__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ target_builders=[target_builder],
+ target_type_rewards=[target_reward_scale],
+ target_type_proportions=[1],
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+
+
+class RepeatSingleGoalMazeAugmentedWithTargets(RepeatSingleGoalMaze):
+ """Augments the single goal maze with many lower reward targets."""
+
+ def __init__(self,
+ walker,
+ main_target,
+ maze_arena,
+ num_subtargets=20,
+ target_reward_scale=10.0,
+ subtarget_reward_scale=1.0,
+ subtarget_colors=((0, 0, 0.4), (0, 0, 0.7)),
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super().__init__(
+ walker=walker,
+ target=main_target,
+ maze_arena=maze_arena,
+ target_reward_scale=target_reward_scale,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ self._subtarget_reward_scale = subtarget_reward_scale
+ self._subtargets = []
+ for i in range(num_subtargets):
+ subtarget = target_sphere.TargetSphere(
+ radius=0.4, rgb1=subtarget_colors[0], rgb2=subtarget_colors[1],
+ name='subtarget_{}'.format(i)
+ )
+ self._subtargets.append(subtarget)
+ self._maze_arena.attach(subtarget)
+ self._subtarget_rewarded = None
+
+ def initialize_episode_mjcf(self, random_state):
+ super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).initialize_episode_mjcf(random_state)
+ subtarget_positions = self._maze_arena.target_positions
+ for pos, subtarget in zip(subtarget_positions, self._subtargets):
+ mjcf.get_attachment_frame(subtarget.mjcf_model).pos = pos
+ self._subtarget_rewarded = [False] * len(self._subtargets)
+
+ def get_reward(self, physics):
+ main_reward = super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).get_reward(physics)
+ subtarget_reward = 0
+ for i, subtarget in enumerate(self._subtargets):
+ if subtarget.activated and not self._subtarget_rewarded[i]:
+ subtarget_reward += 1
+ self._subtarget_rewarded[i] = True
+ subtarget_reward *= self._subtarget_reward_scale
+ return main_reward + subtarget_reward
+
+ def should_terminate_episode(self, physics):
+ if super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).should_terminate_episode(physics):
+ return True
+ else:
+ for subtarget in self._subtargets:
+ if not subtarget.activated:
+ return False
+ # All subtargets have been activated.
+ return True
diff --git a/dm_control/locomotion/tasks/random_goal_maze_test.py b/dm_control/locomotion/tasks/random_goal_maze_test.py
new file mode 100644
index 00000000..999c1454
--- /dev/null
+++ b/dm_control/locomotion/tasks/random_goal_maze_test.py
@@ -0,0 +1,129 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.random_goal_maze."""
+
+import functools
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+
+
+class RandomGoalMazeTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ control_timestep=.03,
+ physics_timestep=.005,
+ )
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ control_timestep=.03,
+ physics_timestep=.005,
+ )
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/reach.py b/dm_control/locomotion/tasks/reach.py
new file mode 100644
index 00000000..1a8cb21b
--- /dev/null
+++ b/dm_control/locomotion/tasks/reach.py
@@ -0,0 +1,287 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A (visuomotor) task consisting of reaching to targets for reward."""
+
+import collections
+import enum
+import itertools
+
+from dm_control import composer
+from dm_control.composer.observation import observable as dm_observable
+import numpy as np
+
+DEFAULT_ALIVE_THRESHOLD = -1.0
+DEFAULT_PHYSICS_TIMESTEP = 0.005
+DEFAULT_CONTROL_TIMESTEP = 0.03
+
+
+class TwoTouchState(enum.IntEnum):
+ PRE_TOUCH = 0
+ TOUCHED_ONCE = 1
+ TOUCHED_TWICE = 2 # at appropriate time
+ TOUCHED_TOO_SOON = 3
+ NO_SECOND_TOUCH = 4
+
+
+class TwoTouch(composer.Task):
+ """Task with target to tap with short delay (for Rat)."""
+
+ def __init__(self,
+ walker,
+ arena,
+ target_builders,
+ target_type_rewards,
+ shuffle_target_builders=False,
+ randomize_spawn_position=False,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ touch_interval=0.8,
+ interval_tolerance=0.1, # consider making a curriculum
+ failure_timeout=1.2,
+ reset_delay=0.,
+ z_height=.14, # 5.5" in real experiments
+ target_area=(),
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ self._walker = walker
+ self._arena = arena
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ if 'CMUHumanoid' in str(type(self._walker)):
+ self._lhand_body = walker.mjcf_model.find('body', 'lhand')
+ self._rhand_body = walker.mjcf_model.find('body', 'rhand')
+ elif 'Rat' in str(type(self._walker)):
+ self._lhand_body = walker.mjcf_model.find('body', 'hand_L')
+ self._rhand_body = walker.mjcf_model.find('body', 'hand_R')
+ else:
+ raise ValueError('Expects Rat or CMUHumanoid.')
+ self._lhand_geoms = self._lhand_body.find_all('geom')
+ self._rhand_geoms = self._rhand_body.find_all('geom')
+
+ self._targets = []
+ self._target_builders = target_builders
+ self._target_type_rewards = tuple(target_type_rewards)
+ self._shuffle_target_builders = shuffle_target_builders
+
+ self._randomize_spawn_position = randomize_spawn_position
+ self._spawn_position = [0.0, 0.0] # x, y
+ self._randomize_spawn_rotation = randomize_spawn_rotation
+ self._rotation_bias_factor = rotation_bias_factor
+
+ self._aliveness_reward = aliveness_reward
+ self._discount = 1.0
+
+ self._touch_interval = touch_interval
+ self._interval_tolerance = interval_tolerance
+ self._failure_timeout = failure_timeout
+ self._reset_delay = reset_delay
+ self._target_positions = []
+ self._state_logic = TwoTouchState.PRE_TOUCH
+
+ self._z_height = z_height
+ arena_size = self._arena.size
+ if target_area:
+ self._target_area = target_area
+ else:
+ self._target_area = [1/2*arena_size[0], 1/2*arena_size[1]]
+ target_x = 1.
+ target_y = 1.
+ self._target_positions.append((target_x, target_y, self._z_height))
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ self._task_observables = collections.OrderedDict()
+ def task_state(physics):
+ del physics
+ return np.array([self._state_logic])
+ self._task_observables['task_logic'] = dm_observable.Generic(task_state)
+
+ self._walker.observables.egocentric_camera.height = 64
+ self._walker.observables.egocentric_camera.width = 64
+
+ for observable in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors +
+ list(self._task_observables.values())):
+ observable.enabled = True
+ self._walker.observables.egocentric_camera.enabled = True
+
+ def _get_targets(self, total_target_count, random_state):
+    # One target of the first type; all remaining targets are of the second.
+ target_numbers = np.array([1, len(self._target_positions)-1])
+
+ if self._shuffle_target_builders:
+ random_state.shuffle(self._target_builders)
+
+ all_targets = []
+ for target_type, num in enumerate(target_numbers):
+ targets = []
+ if num < 1:
+ break
+ target_builder = self._target_builders[target_type]
+ for i in range(num):
+ target = target_builder(name='target_{}_{}'.format(target_type, i))
+ targets.append(target)
+ all_targets.append(targets)
+ return all_targets
+
+ @property
+ def name(self):
+ return 'two_touch'
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def _randomize_targets(self, physics, random_state=np.random):
+ for ii in range(len(self._target_positions)):
+ target_x = self._target_area[0]*random_state.uniform(-1., 1.)
+ target_y = self._target_area[1]*random_state.uniform(-1., 1.)
+ self._target_positions[ii] = (target_x, target_y, self._z_height)
+ target_positions = np.copy(self._target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._targets
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ target.reset(physics)
+ physics.bind(target.geom).pos = pos
+ self._targets = all_targets
+ self._target_rewarded_once = [
+ [False] * len(targets) for targets in all_targets]
+ self._target_rewarded_twice = [
+ [False] * len(targets) for targets in all_targets]
+ self._first_touch_time = None
+ self._second_touch_time = None
+ self._do_time_out = False
+ self._state_logic = TwoTouchState.PRE_TOUCH
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state)
+ for target in itertools.chain(*self._targets):
+ target.detach()
+ target_positions = np.copy(self._target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._get_targets(len(self._target_positions), random_state)
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ self._arena.attach(target)
+ target.geom.pos = pos
+ target.initialize_episode_mjcf(random_state)
+ self._targets = all_targets
+
+ def _respawn_walker(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+
+ if self._randomize_spawn_position:
+ self._spawn_position = self._arena.spawn_positions[
+ random_state.randint(0, len(self._arena.spawn_positions))]
+
+    quat = None  # Leave the orientation unchanged unless randomized below.
+    if self._randomize_spawn_rotation:
+      rotation = 2*np.pi*random_state.uniform()
+      quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ self._walker.shift_pose(
+ physics,
+ [self._spawn_position[0], self._spawn_position[1], 0.0],
+ quat,
+ rotate_velocity=True)
+
+ def initialize_episode(self, physics, random_state):
+ super().initialize_episode(physics, random_state)
+ self._respawn_walker(physics, random_state)
+ self._state_logic = TwoTouchState.PRE_TOUCH
+ self._discount = 1.0
+ self._lhand_geomids = set(physics.bind(self._lhand_geoms).element_id)
+ self._rhand_geomids = set(physics.bind(self._rhand_geoms).element_id)
+ self._hand_geomids = self._lhand_geomids | self._rhand_geomids
+    self._randomize_targets(physics, random_state)
+ self._must_randomize_targets = False
+ for target in itertools.chain(*self._targets):
+ target._specific_collision_geom_ids = self._hand_geomids # pylint: disable=protected-access
+
+ def before_step(self, physics, action, random_state):
+ super().before_step(physics, action, random_state)
+ if self._must_randomize_targets:
+      self._randomize_targets(physics, random_state)
+ self._must_randomize_targets = False
+
+ def should_terminate_episode(self, physics):
+ failure_termination = False
+ if failure_termination:
+ self._discount = 0.0
+ return True
+ else:
+ return False
+
+ def get_reward(self, physics):
+ reward = self._aliveness_reward
+ lhand_pos = physics.bind(self._lhand_body).xpos
+ rhand_pos = physics.bind(self._rhand_body).xpos
+ target_pos = physics.bind(self._targets[0][0].geom).xpos
+ lhand_rew = np.exp(-3.*sum(np.abs(lhand_pos-target_pos)))
+ rhand_rew = np.exp(-3.*sum(np.abs(rhand_pos-target_pos)))
+ closeness_reward = np.maximum(lhand_rew, rhand_rew)
+ reward += .01*closeness_reward*self._target_type_rewards[0]
+ if self._state_logic == TwoTouchState.PRE_TOUCH:
+ # touch the first time
+ for target_type, targets in enumerate(self._targets):
+ for i, target in enumerate(targets):
+ if (target.activated[0] and
+ not self._target_rewarded_once[target_type][i]):
+ self._first_touch_time = physics.time()
+ self._state_logic = TwoTouchState.TOUCHED_ONCE
+ self._target_rewarded_once[target_type][i] = True
+ reward += self._target_type_rewards[target_type]
+ elif self._state_logic == TwoTouchState.TOUCHED_ONCE:
+ for target_type, targets in enumerate(self._targets):
+ for i, target in enumerate(targets):
+ if (target.activated[1] and
+ not self._target_rewarded_twice[target_type][i]):
+ self._second_touch_time = physics.time()
+ self._state_logic = TwoTouchState.TOUCHED_TWICE
+ self._target_rewarded_twice[target_type][i] = True
+ # check if touched too soon
+ if ((self._second_touch_time - self._first_touch_time) <
+ (self._touch_interval - self._interval_tolerance)):
+ self._do_time_out = True
+ self._state_logic = TwoTouchState.TOUCHED_TOO_SOON
+ # check if touched at correct time
+ elif ((self._second_touch_time - self._first_touch_time) <=
+ (self._touch_interval + self._interval_tolerance)):
+ reward += self._target_type_rewards[target_type]
+ # check if no second touch within time interval
+ if ((physics.time() - self._first_touch_time) >
+ (self._touch_interval + self._interval_tolerance)):
+ self._do_time_out = True
+ self._state_logic = TwoTouchState.NO_SECOND_TOUCH
+ self._second_touch_time = physics.time()
+ elif (self._state_logic == TwoTouchState.TOUCHED_TWICE or
+ self._state_logic == TwoTouchState.TOUCHED_TOO_SOON or
+ self._state_logic == TwoTouchState.NO_SECOND_TOUCH):
+ # hold here due to timeout
+ if self._do_time_out:
+ if physics.time() > (self._second_touch_time + self._failure_timeout):
+ self._do_time_out = False
+ # reset/re-randomize
+ elif physics.time() > (self._second_touch_time + self._reset_delay):
+ self._must_randomize_targets = True
+ return reward
+
+ def get_discount(self, physics):
+ del physics
+ return self._discount
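Note on the timing logic in `get_reward` above: the second-touch check amounts to a three-way classification of the time elapsed since the first touch. The following is a minimal sketch of that classification only, assuming the same `touch_interval` and `interval_tolerance` semantics as the task. `TouchOutcome` and `classify_second_touch` are hypothetical names that do not appear in this change, and the sketch leaves out rewards, timeouts, and target re-randomization.

```python
import enum


class TouchOutcome(enum.Enum):
  """Hypothetical labels that mirror the TwoTouchState values used above."""
  TOO_SOON = 'touched_too_soon'
  ON_TIME = 'touched_twice'
  TOO_LATE = 'no_second_touch'


def classify_second_touch(first_touch_time, second_touch_time,
                          touch_interval, interval_tolerance):
  """Classifies a second touch by the time elapsed since the first touch."""
  elapsed = second_touch_time - first_touch_time
  # Before the lower edge of the window: the touch came too early.
  if elapsed < touch_interval - interval_tolerance:
    return TouchOutcome.TOO_SOON
  # Inside the window, with an inclusive upper bound as in get_reward.
  if elapsed <= touch_interval + interval_tolerance:
    return TouchOutcome.ON_TIME
  # Past the upper edge: treated like a missing second touch.
  return TouchOutcome.TOO_LATE
```

With `touch_interval=1.0` and `interval_tolerance=0.25`, a second touch 0.5 s after the first is too soon, one between 0.75 s and 1.25 s is on time, and anything later is too late.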
diff --git a/dm_control/locomotion/tasks/reach_test.py b/dm_control/locomotion/tasks/reach_test.py
new file mode 100644
index 00000000..5fa90e2f
--- /dev/null
+++ b/dm_control/locomotion/tasks/reach_test.py
@@ -0,0 +1,61 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.reach."""
+
+import functools
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import reach
+from dm_control.locomotion.walkers import rodent
+import numpy as np
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+class ReachTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = rodent.Rat()
+
+ arena = floors.Floor(
+ size=(10., 10.),
+ aesthetic='outdoor_natural')
+
+ task = reach.TwoTouch(
+ walker=walker,
+ arena=arena,
+ target_builders=[
+ functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025),
+ ],
+ randomize_spawn_rotation=True,
+ target_type_rewards=[25.],
+ shuffle_target_builders=False,
+ target_area=(1.5, 1.5),
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP,
+ )
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/reference_pose/README.md b/dm_control/locomotion/tasks/reference_pose/README.md
new file mode 100644
index 00000000..741d6ac8
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/README.md
@@ -0,0 +1,16 @@
+# Reference pose tasks
+
+This directory contains components to define tasks based on reference poses
+(e.g. motion capture data) as well as a motion capture tracking task. The tasks
+and associated utils were developed as part of
+[CoMic: Complementary Task Learning & Mimicry for Reusable Skills (2020)][hasenclever2020].
+
+The reference data is stored in HDF5 files, which can be loaded using the
+`HDF5TrajectoryLoader` class in `dm_control/locomotion/mocap/loader.py`. To
+download the data used in the CoMic project, please use
+`dm_control/locomotion/mocap/cmu_mocap_data.py`. In the reference pose tasks,
+reference trajectories are represented as `Trajectory` objects (see
+`dm_control/locomotion/mocap/trajectory.py`). For an example of how to construct
+a task, see `dm_control/locomotion/examples/cmu_2020_tracking.py`.
+
+[hasenclever2020]: https://proceedings.icml.cc/static/paper_files/icml/2020/5013-Paper.pdf
diff --git a/dm_control/locomotion/tasks/reference_pose/__init__.py b/dm_control/locomotion/tasks/reference_pose/__init__.py
new file mode 100644
index 00000000..fc649923
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Reference pose tasks in the Locomotion library."""
diff --git a/dm_control/locomotion/tasks/reference_pose/cmu_subsets.py b/dm_control/locomotion/tasks/reference_pose/cmu_subsets.py
new file mode 100644
index 00000000..b10eea35
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/cmu_subsets.py
@@ -0,0 +1,1289 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Subsets of the CMU mocap database."""
+
+from dm_control.locomotion.tasks.reference_pose import types
+
+ClipCollection = types.ClipCollection
+
+# get up
+GET_UP = ClipCollection(
+ ids=('CMU_139_16',
+ 'CMU_139_17',
+ 'CMU_139_18',
+ 'CMU_140_01',
+ 'CMU_140_02',
+ 'CMU_140_08',
+ 'CMU_140_09')
+)
+
+# Subset of about 40 minutes of varied locomotion behaviors.
+LOCOMOTION_SMALL = ClipCollection(
+ ids=('CMU_001_01',
+ 'CMU_002_03',
+ 'CMU_002_04',
+ 'CMU_009_01',
+ 'CMU_009_02',
+ 'CMU_009_03',
+ 'CMU_009_04',
+ 'CMU_009_05',
+ 'CMU_009_06',
+ 'CMU_009_07',
+ 'CMU_009_08',
+ 'CMU_009_09',
+ 'CMU_009_10',
+ 'CMU_009_11',
+ 'CMU_013_11',
+ 'CMU_013_13',
+ 'CMU_013_19',
+ 'CMU_013_32',
+ 'CMU_013_39',
+ 'CMU_013_40',
+ 'CMU_013_41',
+ 'CMU_013_42',
+ 'CMU_014_07',
+ 'CMU_014_08',
+ 'CMU_014_09',
+ 'CMU_016_01',
+ 'CMU_016_02',
+ 'CMU_016_03',
+ 'CMU_016_04',
+ 'CMU_016_05',
+ 'CMU_016_06',
+ 'CMU_016_07',
+ 'CMU_016_08',
+ 'CMU_016_09',
+ 'CMU_016_10',
+ 'CMU_016_17',
+ 'CMU_016_18',
+ 'CMU_016_19',
+ 'CMU_016_20',
+ 'CMU_016_27',
+ 'CMU_016_28',
+ 'CMU_016_29',
+ 'CMU_016_30',
+ 'CMU_016_35',
+ 'CMU_016_36',
+ 'CMU_016_37',
+ 'CMU_016_38',
+ 'CMU_016_39',
+ 'CMU_016_40',
+ 'CMU_016_41',
+ 'CMU_016_42',
+ 'CMU_016_43',
+ 'CMU_016_44',
+ 'CMU_016_45',
+ 'CMU_016_46',
+ 'CMU_016_48',
+ 'CMU_016_49',
+ 'CMU_016_50',
+ 'CMU_016_51',
+ 'CMU_016_52',
+ 'CMU_016_53',
+ 'CMU_016_54',
+ 'CMU_016_55',
+ 'CMU_016_56',
+ 'CMU_016_57',
+ 'CMU_035_17',
+ 'CMU_035_18',
+ 'CMU_035_19',
+ 'CMU_035_20',
+ 'CMU_035_21',
+ 'CMU_035_22',
+ 'CMU_035_23',
+ 'CMU_035_24',
+ 'CMU_035_25',
+ 'CMU_035_26',
+ 'CMU_036_02',
+ 'CMU_036_03',
+ 'CMU_036_09',
+ 'CMU_038_03',
+ 'CMU_038_04',
+ 'CMU_039_11',
+ 'CMU_047_01',
+ 'CMU_049_02',
+ 'CMU_049_03',
+ 'CMU_049_04',
+ 'CMU_049_05',
+ 'CMU_069_06',
+ 'CMU_069_07',
+ 'CMU_069_08',
+ 'CMU_069_09',
+ 'CMU_069_10',
+ 'CMU_069_11',
+ 'CMU_069_12',
+ 'CMU_069_13',
+ 'CMU_069_14',
+ 'CMU_069_15',
+ 'CMU_069_16',
+ 'CMU_069_17',
+ 'CMU_069_18',
+ 'CMU_069_19',
+ 'CMU_069_20',
+ 'CMU_069_21',
+ 'CMU_069_22',
+ 'CMU_069_23',
+ 'CMU_069_24',
+ 'CMU_069_25',
+ 'CMU_069_26',
+ 'CMU_069_27',
+ 'CMU_069_28',
+ 'CMU_069_29',
+ 'CMU_069_30',
+ 'CMU_069_31',
+ 'CMU_069_32',
+ 'CMU_069_33',
+ 'CMU_069_42',
+ 'CMU_069_43',
+ 'CMU_069_44',
+ 'CMU_069_45',
+ 'CMU_069_46',
+ 'CMU_069_47',
+ 'CMU_069_48',
+ 'CMU_069_49',
+ 'CMU_069_56',
+ 'CMU_069_57',
+ 'CMU_069_58',
+ 'CMU_069_59',
+ 'CMU_069_60',
+ 'CMU_069_61',
+ 'CMU_069_62',
+ 'CMU_069_63',
+ 'CMU_069_64',
+ 'CMU_069_65',
+ 'CMU_069_66',
+ 'CMU_069_67',
+ 'CMU_075_01',
+ 'CMU_075_02',
+ 'CMU_075_03',
+ 'CMU_075_04',
+ 'CMU_075_05',
+ 'CMU_075_06',
+ 'CMU_075_07',
+ 'CMU_075_08',
+ 'CMU_075_09',
+ 'CMU_075_10',
+ 'CMU_075_11',
+ 'CMU_075_12',
+ 'CMU_075_13',
+ 'CMU_075_14',
+ 'CMU_075_15',
+ 'CMU_076_10',
+ 'CMU_077_10',
+ 'CMU_077_11',
+ 'CMU_077_12',
+ 'CMU_077_13',
+ 'CMU_078_01',
+ 'CMU_078_02',
+ 'CMU_078_03',
+ 'CMU_078_07',
+ 'CMU_078_09',
+ 'CMU_078_10',
+ 'CMU_082_15',
+ 'CMU_083_36',
+ 'CMU_083_37',
+ 'CMU_083_38',
+ 'CMU_083_39',
+ 'CMU_083_40',
+ 'CMU_083_41',
+ 'CMU_083_42',
+ 'CMU_083_43',
+ 'CMU_083_45',
+ 'CMU_083_46',
+ 'CMU_083_48',
+ 'CMU_083_49',
+ 'CMU_083_51',
+ 'CMU_083_52',
+ 'CMU_083_53',
+ 'CMU_083_54',
+ 'CMU_083_56',
+ 'CMU_083_57',
+ 'CMU_083_58',
+ 'CMU_083_59',
+ 'CMU_083_60',
+ 'CMU_083_61',
+ 'CMU_083_62',
+ 'CMU_083_64',
+ 'CMU_083_65',
+ 'CMU_086_01',
+ 'CMU_086_11',
+ 'CMU_090_06',
+ 'CMU_090_07',
+ 'CMU_091_39',
+ 'CMU_091_40',
+ 'CMU_091_41',
+ 'CMU_091_42',
+ 'CMU_091_43',
+ 'CMU_091_44',
+ 'CMU_091_45',
+ 'CMU_091_46',
+ 'CMU_091_47',
+ 'CMU_091_48',
+ 'CMU_091_49',
+ 'CMU_091_50',
+ 'CMU_091_51',
+ 'CMU_091_52',
+ 'CMU_091_53',
+ 'CMU_104_53',
+ 'CMU_104_54',
+ 'CMU_104_55',
+ 'CMU_104_56',
+ 'CMU_104_57',
+ 'CMU_105_39',
+ 'CMU_105_40',
+ 'CMU_105_41',
+ 'CMU_105_42',
+ 'CMU_105_43',
+ 'CMU_105_44',
+ 'CMU_105_45',
+ 'CMU_105_46',
+ 'CMU_105_47',
+ 'CMU_105_48',
+ 'CMU_105_49',
+ 'CMU_105_50',
+ 'CMU_105_51',
+ 'CMU_105_52',
+ 'CMU_118_01',
+ 'CMU_118_02',
+ 'CMU_118_03',
+ 'CMU_118_04',
+ 'CMU_118_05',
+ 'CMU_118_06',
+ 'CMU_118_07',
+ 'CMU_118_08',
+ 'CMU_118_09',
+ 'CMU_118_10',
+ 'CMU_118_11',
+ 'CMU_118_12',
+ 'CMU_118_13',
+ 'CMU_118_14',
+ 'CMU_118_15',
+ 'CMU_118_16',
+ 'CMU_118_17',
+ 'CMU_118_18',
+ 'CMU_118_19',
+ 'CMU_118_20',
+ 'CMU_118_21',
+ 'CMU_118_22',
+ 'CMU_118_23',
+ 'CMU_118_24',
+ 'CMU_118_25',
+ 'CMU_118_26',
+ 'CMU_118_27',
+ 'CMU_118_28',
+ 'CMU_118_29',
+ 'CMU_118_30',
+ 'CMU_127_03',
+ 'CMU_127_04',
+ 'CMU_127_05',
+ 'CMU_127_06',
+ 'CMU_127_07',
+ 'CMU_127_08',
+ 'CMU_127_09',
+ 'CMU_127_10',
+ 'CMU_127_11',
+ 'CMU_127_12',
+ 'CMU_127_13',
+ 'CMU_127_14',
+ 'CMU_127_15',
+ 'CMU_127_16',
+ 'CMU_127_17',
+ 'CMU_127_18',
+ 'CMU_127_19',
+ 'CMU_127_20',
+ 'CMU_127_21',
+ 'CMU_127_22',
+ 'CMU_127_23',
+ 'CMU_127_24',
+ 'CMU_127_25',
+ 'CMU_127_26',
+ 'CMU_127_27',
+ 'CMU_127_28',
+ 'CMU_127_29',
+ 'CMU_127_30',
+ 'CMU_127_31',
+ 'CMU_127_32',
+ 'CMU_127_37',
+ 'CMU_127_38',
+ 'CMU_128_02',
+ 'CMU_128_03',
+ 'CMU_128_04',
+ 'CMU_128_05',
+ 'CMU_128_06',
+ 'CMU_128_07',
+ 'CMU_128_08',
+ 'CMU_128_09',
+ 'CMU_128_10',
+ 'CMU_128_11',
+ 'CMU_132_23',
+ 'CMU_132_24',
+ 'CMU_132_25',
+ 'CMU_132_26',
+ 'CMU_132_27',
+ 'CMU_132_28',
+ 'CMU_139_10',
+ 'CMU_139_11',
+ 'CMU_139_12',
+ 'CMU_139_13',
+ 'CMU_143_01',
+ 'CMU_143_02',
+ 'CMU_143_03',
+ 'CMU_143_04',
+ 'CMU_143_05',
+ 'CMU_143_06',
+ 'CMU_143_07',
+ 'CMU_143_08',
+ 'CMU_143_09',
+ 'CMU_143_42'))
+
+# Subset of about 2 minutes of walking behaviors.
+WALK_TINY = ClipCollection(
+ ids=('CMU_016_22',
+ 'CMU_016_23',
+ 'CMU_016_24',
+ 'CMU_016_25',
+ 'CMU_016_26',
+ 'CMU_016_27',
+ 'CMU_016_28',
+ 'CMU_016_29',
+ 'CMU_016_30',
+ 'CMU_016_31',
+ 'CMU_016_32',
+ 'CMU_016_33',
+ 'CMU_016_34',
+ 'CMU_016_47',
+ 'CMU_016_58',
+ 'CMU_047_01',
+ 'CMU_056_01',
+ 'CMU_069_01',
+ 'CMU_069_02',
+ 'CMU_069_03',
+ 'CMU_069_04',
+ 'CMU_069_05',
+ 'CMU_069_20',
+ 'CMU_069_21',
+ 'CMU_069_22',
+ 'CMU_069_23',
+ 'CMU_069_24',
+ 'CMU_069_25',
+ 'CMU_069_26',
+ 'CMU_069_27',
+ 'CMU_069_28',
+ 'CMU_069_29',
+ 'CMU_069_30',
+ 'CMU_069_31',
+ 'CMU_069_32',
+ 'CMU_069_33'))
+
+# Subset of about 2 minutes of walking/running/jumping behaviors.
+RUN_JUMP_TINY = ClipCollection(
+ ids=('CMU_009_01',
+ 'CMU_009_02',
+ 'CMU_009_03',
+ 'CMU_009_04',
+ 'CMU_009_05',
+ 'CMU_009_06',
+ 'CMU_009_07',
+ 'CMU_009_08',
+ 'CMU_009_09',
+ 'CMU_009_10',
+ 'CMU_009_11',
+ 'CMU_016_22',
+ 'CMU_016_23',
+ 'CMU_016_24',
+ 'CMU_016_25',
+ 'CMU_016_26',
+ 'CMU_016_27',
+ 'CMU_016_28',
+ 'CMU_016_29',
+ 'CMU_016_30',
+ 'CMU_016_31',
+ 'CMU_016_32',
+ 'CMU_016_47',
+ 'CMU_016_48',
+ 'CMU_016_49',
+ 'CMU_016_50',
+ 'CMU_016_55',
+ 'CMU_016_58',
+ 'CMU_049_04',
+ 'CMU_049_05',
+ 'CMU_069_01',
+ 'CMU_069_02',
+ 'CMU_069_03',
+ 'CMU_069_04',
+ 'CMU_069_05',
+ 'CMU_075_01',
+ 'CMU_075_02',
+ 'CMU_075_03',
+ 'CMU_075_10',
+ 'CMU_075_11',
+ 'CMU_127_03',
+ 'CMU_127_06',
+ 'CMU_127_07',
+ 'CMU_127_08',
+ 'CMU_127_09',
+ 'CMU_127_10',
+ 'CMU_127_11',
+ 'CMU_127_12',
+ 'CMU_128_02',
+ 'CMU_128_03'))
+
+# Subset of about 3.5 hours of varied locomotion behaviors and hand movements.
+ALL = ClipCollection(
+ ids=('CMU_001_01',
+ 'CMU_002_01',
+ 'CMU_002_02',
+ 'CMU_002_03',
+ 'CMU_002_04',
+ 'CMU_005_01',
+ 'CMU_006_01',
+ 'CMU_006_02',
+ 'CMU_006_03',
+ 'CMU_006_04',
+ 'CMU_006_05',
+ 'CMU_006_06',
+ 'CMU_006_07',
+ 'CMU_006_08',
+ 'CMU_006_09',
+ 'CMU_006_10',
+ 'CMU_006_11',
+ 'CMU_006_12',
+ 'CMU_006_13',
+ 'CMU_006_14',
+ 'CMU_006_15',
+ 'CMU_007_01',
+ 'CMU_007_02',
+ 'CMU_007_03',
+ 'CMU_007_04',
+ 'CMU_007_05',
+ 'CMU_007_06',
+ 'CMU_007_07',
+ 'CMU_007_08',
+ 'CMU_007_09',
+ 'CMU_007_10',
+ 'CMU_007_11',
+ 'CMU_007_12',
+ 'CMU_008_01',
+ 'CMU_008_02',
+ 'CMU_008_03',
+ 'CMU_008_04',
+ 'CMU_008_05',
+ 'CMU_008_06',
+ 'CMU_008_07',
+ 'CMU_008_08',
+ 'CMU_008_09',
+ 'CMU_008_10',
+ 'CMU_008_11',
+ 'CMU_009_01',
+ 'CMU_009_02',
+ 'CMU_009_03',
+ 'CMU_009_04',
+ 'CMU_009_05',
+ 'CMU_009_06',
+ 'CMU_009_07',
+ 'CMU_009_08',
+ 'CMU_009_09',
+ 'CMU_009_10',
+ 'CMU_009_11',
+ 'CMU_009_12',
+ 'CMU_010_04',
+ 'CMU_013_11',
+ 'CMU_013_13',
+ 'CMU_013_19',
+ 'CMU_013_26',
+ 'CMU_013_27',
+ 'CMU_013_28',
+ 'CMU_013_29',
+ 'CMU_013_30',
+ 'CMU_013_31',
+ 'CMU_013_32',
+ 'CMU_013_39',
+ 'CMU_013_40',
+ 'CMU_013_41',
+ 'CMU_013_42',
+ 'CMU_014_06',
+ 'CMU_014_07',
+ 'CMU_014_08',
+ 'CMU_014_09',
+ 'CMU_014_14',
+ 'CMU_014_20',
+ 'CMU_014_24',
+ 'CMU_014_25',
+ 'CMU_014_26',
+ 'CMU_015_01',
+ 'CMU_015_03',
+ 'CMU_015_04',
+ 'CMU_015_05',
+ 'CMU_015_06',
+ 'CMU_015_07',
+ 'CMU_015_08',
+ 'CMU_015_09',
+ 'CMU_015_12',
+ 'CMU_015_14',
+ 'CMU_016_01',
+ 'CMU_016_02',
+ 'CMU_016_03',
+ 'CMU_016_04',
+ 'CMU_016_05',
+ 'CMU_016_06',
+ 'CMU_016_07',
+ 'CMU_016_08',
+ 'CMU_016_09',
+ 'CMU_016_10',
+ 'CMU_016_11',
+ 'CMU_016_12',
+ 'CMU_016_13',
+ 'CMU_016_14',
+ 'CMU_016_15',
+ 'CMU_016_16',
+ 'CMU_016_17',
+ 'CMU_016_18',
+ 'CMU_016_19',
+ 'CMU_016_20',
+ 'CMU_016_21',
+ 'CMU_016_22',
+ 'CMU_016_23',
+ 'CMU_016_24',
+ 'CMU_016_25',
+ 'CMU_016_26',
+ 'CMU_016_27',
+ 'CMU_016_28',
+ 'CMU_016_29',
+ 'CMU_016_30',
+ 'CMU_016_31',
+ 'CMU_016_32',
+ 'CMU_016_33',
+ 'CMU_016_34',
+ 'CMU_016_35',
+ 'CMU_016_36',
+ 'CMU_016_37',
+ 'CMU_016_38',
+ 'CMU_016_39',
+ 'CMU_016_40',
+ 'CMU_016_41',
+ 'CMU_016_42',
+ 'CMU_016_43',
+ 'CMU_016_44',
+ 'CMU_016_45',
+ 'CMU_016_46',
+ 'CMU_016_47',
+ 'CMU_016_48',
+ 'CMU_016_49',
+ 'CMU_016_50',
+ 'CMU_016_51',
+ 'CMU_016_52',
+ 'CMU_016_53',
+ 'CMU_016_54',
+ 'CMU_016_55',
+ 'CMU_016_56',
+ 'CMU_016_57',
+ 'CMU_016_58',
+ 'CMU_017_01',
+ 'CMU_017_02',
+ 'CMU_017_03',
+ 'CMU_017_04',
+ 'CMU_017_05',
+ 'CMU_017_06',
+ 'CMU_017_07',
+ 'CMU_017_08',
+ 'CMU_017_09',
+ 'CMU_017_10',
+ 'CMU_024_01',
+ 'CMU_025_01',
+ 'CMU_026_01',
+ 'CMU_026_02',
+ 'CMU_026_03',
+ 'CMU_026_04',
+ 'CMU_026_05',
+ 'CMU_026_06',
+ 'CMU_026_07',
+ 'CMU_026_08',
+ 'CMU_027_01',
+ 'CMU_027_02',
+ 'CMU_027_03',
+ 'CMU_027_04',
+ 'CMU_027_05',
+ 'CMU_027_06',
+ 'CMU_027_07',
+ 'CMU_027_08',
+ 'CMU_027_09',
+ 'CMU_029_01',
+ 'CMU_029_02',
+ 'CMU_029_03',
+ 'CMU_029_04',
+ 'CMU_029_05',
+ 'CMU_029_06',
+ 'CMU_029_07',
+ 'CMU_029_08',
+ 'CMU_029_09',
+ 'CMU_029_10',
+ 'CMU_029_11',
+ 'CMU_029_12',
+ 'CMU_029_13',
+ 'CMU_031_01',
+ 'CMU_031_02',
+ 'CMU_031_03',
+ 'CMU_031_06',
+ 'CMU_031_07',
+ 'CMU_031_08',
+ 'CMU_032_01',
+ 'CMU_032_02',
+ 'CMU_032_03',
+ 'CMU_032_04',
+ 'CMU_032_05',
+ 'CMU_032_06',
+ 'CMU_032_07',
+ 'CMU_032_08',
+ 'CMU_032_09',
+ 'CMU_032_10',
+ 'CMU_032_11',
+ 'CMU_035_01',
+ 'CMU_035_02',
+ 'CMU_035_03',
+ 'CMU_035_04',
+ 'CMU_035_05',
+ 'CMU_035_06',
+ 'CMU_035_07',
+ 'CMU_035_08',
+ 'CMU_035_09',
+ 'CMU_035_10',
+ 'CMU_035_11',
+ 'CMU_035_12',
+ 'CMU_035_13',
+ 'CMU_035_14',
+ 'CMU_035_15',
+ 'CMU_035_16',
+ 'CMU_035_17',
+ 'CMU_035_18',
+ 'CMU_035_19',
+ 'CMU_035_20',
+ 'CMU_035_21',
+ 'CMU_035_22',
+ 'CMU_035_23',
+ 'CMU_035_24',
+ 'CMU_035_25',
+ 'CMU_035_26',
+ 'CMU_035_27',
+ 'CMU_035_28',
+ 'CMU_035_29',
+ 'CMU_035_30',
+ 'CMU_035_31',
+ 'CMU_035_32',
+ 'CMU_035_33',
+ 'CMU_035_34',
+ 'CMU_036_02',
+ 'CMU_036_03',
+ 'CMU_036_09',
+ 'CMU_037_01',
+ 'CMU_038_01',
+ 'CMU_038_02',
+ 'CMU_038_03',
+ 'CMU_038_04',
+ 'CMU_039_11',
+ 'CMU_040_02',
+ 'CMU_040_03',
+ 'CMU_040_04',
+ 'CMU_040_05',
+ 'CMU_040_10',
+ 'CMU_040_11',
+ 'CMU_040_12',
+ 'CMU_041_02',
+ 'CMU_041_03',
+ 'CMU_041_04',
+ 'CMU_041_05',
+ 'CMU_041_06',
+ 'CMU_041_10',
+ 'CMU_041_11',
+ 'CMU_045_01',
+ 'CMU_046_01',
+ 'CMU_047_01',
+ 'CMU_049_01',
+ 'CMU_049_02',
+ 'CMU_049_03',
+ 'CMU_049_04',
+ 'CMU_049_05',
+ 'CMU_049_06',
+ 'CMU_049_07',
+ 'CMU_049_08',
+ 'CMU_049_09',
+ 'CMU_049_10',
+ 'CMU_049_11',
+ 'CMU_049_12',
+ 'CMU_049_13',
+ 'CMU_049_14',
+ 'CMU_049_15',
+ 'CMU_049_16',
+ 'CMU_049_17',
+ 'CMU_049_18',
+ 'CMU_049_19',
+ 'CMU_049_20',
+ 'CMU_049_22',
+ 'CMU_056_01',
+ 'CMU_056_04',
+ 'CMU_056_05',
+ 'CMU_056_06',
+ 'CMU_056_07',
+ 'CMU_056_08',
+ 'CMU_060_02',
+ 'CMU_060_03',
+ 'CMU_060_05',
+ 'CMU_060_12',
+ 'CMU_060_14',
+ 'CMU_061_01',
+ 'CMU_061_02',
+ 'CMU_061_03',
+ 'CMU_061_04',
+ 'CMU_061_05',
+ 'CMU_061_06',
+ 'CMU_061_07',
+ 'CMU_061_08',
+ 'CMU_061_09',
+ 'CMU_061_10',
+ 'CMU_061_15',
+ 'CMU_069_01',
+ 'CMU_069_02',
+ 'CMU_069_03',
+ 'CMU_069_04',
+ 'CMU_069_05',
+ 'CMU_069_06',
+ 'CMU_069_07',
+ 'CMU_069_08',
+ 'CMU_069_09',
+ 'CMU_069_10',
+ 'CMU_069_11',
+ 'CMU_069_12',
+ 'CMU_069_13',
+ 'CMU_069_14',
+ 'CMU_069_15',
+ 'CMU_069_16',
+ 'CMU_069_17',
+ 'CMU_069_18',
+ 'CMU_069_19',
+ 'CMU_069_20',
+ 'CMU_069_21',
+ 'CMU_069_22',
+ 'CMU_069_23',
+ 'CMU_069_24',
+ 'CMU_069_25',
+ 'CMU_069_26',
+ 'CMU_069_27',
+ 'CMU_069_28',
+ 'CMU_069_29',
+ 'CMU_069_30',
+ 'CMU_069_31',
+ 'CMU_069_32',
+ 'CMU_069_33',
+ 'CMU_069_34',
+ 'CMU_069_36',
+ 'CMU_069_37',
+ 'CMU_069_38',
+ 'CMU_069_39',
+ 'CMU_069_40',
+ 'CMU_069_41',
+ 'CMU_069_42',
+ 'CMU_069_43',
+ 'CMU_069_44',
+ 'CMU_069_45',
+ 'CMU_069_46',
+ 'CMU_069_47',
+ 'CMU_069_48',
+ 'CMU_069_49',
+ 'CMU_069_50',
+ 'CMU_069_51',
+ 'CMU_069_52',
+ 'CMU_069_53',
+ 'CMU_069_54',
+ 'CMU_069_55',
+ 'CMU_069_56',
+ 'CMU_069_57',
+ 'CMU_069_58',
+ 'CMU_069_59',
+ 'CMU_069_60',
+ 'CMU_069_61',
+ 'CMU_069_62',
+ 'CMU_069_63',
+ 'CMU_069_64',
+ 'CMU_069_65',
+ 'CMU_069_66',
+ 'CMU_069_67',
+ 'CMU_075_01',
+ 'CMU_075_02',
+ 'CMU_075_03',
+ 'CMU_075_04',
+ 'CMU_075_05',
+ 'CMU_075_06',
+ 'CMU_075_07',
+ 'CMU_075_08',
+ 'CMU_075_09',
+ 'CMU_075_10',
+ 'CMU_075_11',
+ 'CMU_075_12',
+ 'CMU_075_13',
+ 'CMU_075_14',
+ 'CMU_075_15',
+ 'CMU_076_01',
+ 'CMU_076_02',
+ 'CMU_076_06',
+ 'CMU_076_08',
+ 'CMU_076_09',
+ 'CMU_076_10',
+ 'CMU_076_11',
+ 'CMU_077_02',
+ 'CMU_077_03',
+ 'CMU_077_04',
+ 'CMU_077_10',
+ 'CMU_077_11',
+ 'CMU_077_12',
+ 'CMU_077_13',
+ 'CMU_077_14',
+ 'CMU_077_15',
+ 'CMU_077_16',
+ 'CMU_077_17',
+ 'CMU_077_18',
+ 'CMU_077_21',
+ 'CMU_077_27',
+ 'CMU_077_28',
+ 'CMU_077_29',
+ 'CMU_077_30',
+ 'CMU_077_31',
+ 'CMU_077_32',
+ 'CMU_077_33',
+ 'CMU_077_34',
+ 'CMU_078_01',
+ 'CMU_078_02',
+ 'CMU_078_03',
+ 'CMU_078_04',
+ 'CMU_078_05',
+ 'CMU_078_07',
+ 'CMU_078_09',
+ 'CMU_078_10',
+ 'CMU_078_13',
+ 'CMU_078_14',
+ 'CMU_078_15',
+ 'CMU_078_16',
+ 'CMU_078_17',
+ 'CMU_078_18',
+ 'CMU_078_19',
+ 'CMU_078_20',
+ 'CMU_078_21',
+ 'CMU_078_22',
+ 'CMU_078_23',
+ 'CMU_078_24',
+ 'CMU_078_25',
+ 'CMU_078_26',
+ 'CMU_078_27',
+ 'CMU_078_28',
+ 'CMU_078_29',
+ 'CMU_078_30',
+ 'CMU_078_31',
+ 'CMU_078_32',
+ 'CMU_078_33',
+ 'CMU_082_08',
+ 'CMU_082_09',
+ 'CMU_082_10',
+ 'CMU_082_11',
+ 'CMU_082_14',
+ 'CMU_082_15',
+ 'CMU_083_18',
+ 'CMU_083_19',
+ 'CMU_083_20',
+ 'CMU_083_21',
+ 'CMU_083_33',
+ 'CMU_083_36',
+ 'CMU_083_37',
+ 'CMU_083_38',
+ 'CMU_083_39',
+ 'CMU_083_40',
+ 'CMU_083_41',
+ 'CMU_083_42',
+ 'CMU_083_43',
+ 'CMU_083_44',
+ 'CMU_083_45',
+ 'CMU_083_46',
+ 'CMU_083_48',
+ 'CMU_083_49',
+ 'CMU_083_51',
+ 'CMU_083_52',
+ 'CMU_083_53',
+ 'CMU_083_54',
+ 'CMU_083_55',
+ 'CMU_083_56',
+ 'CMU_083_57',
+ 'CMU_083_58',
+ 'CMU_083_59',
+ 'CMU_083_60',
+ 'CMU_083_61',
+ 'CMU_083_62',
+ 'CMU_083_63',
+ 'CMU_083_64',
+ 'CMU_083_65',
+ 'CMU_083_66',
+ 'CMU_083_67',
+ 'CMU_086_01',
+ 'CMU_086_02',
+ 'CMU_086_03',
+ 'CMU_086_07',
+ 'CMU_086_08',
+ 'CMU_086_11',
+ 'CMU_086_14',
+ 'CMU_090_06',
+ 'CMU_090_07',
+ 'CMU_091_01',
+ 'CMU_091_02',
+ 'CMU_091_03',
+ 'CMU_091_04',
+ 'CMU_091_05',
+ 'CMU_091_06',
+ 'CMU_091_07',
+ 'CMU_091_08',
+ 'CMU_091_10',
+ 'CMU_091_11',
+ 'CMU_091_12',
+ 'CMU_091_13',
+ 'CMU_091_14',
+ 'CMU_091_15',
+ 'CMU_091_16',
+ 'CMU_091_17',
+ 'CMU_091_18',
+ 'CMU_091_19',
+ 'CMU_091_20',
+ 'CMU_091_21',
+ 'CMU_091_22',
+ 'CMU_091_23',
+ 'CMU_091_24',
+ 'CMU_091_25',
+ 'CMU_091_26',
+ 'CMU_091_27',
+ 'CMU_091_28',
+ 'CMU_091_29',
+ 'CMU_091_30',
+ 'CMU_091_31',
+ 'CMU_091_32',
+ 'CMU_091_33',
+ 'CMU_091_34',
+ 'CMU_091_35',
+ 'CMU_091_36',
+ 'CMU_091_37',
+ 'CMU_091_38',
+ 'CMU_091_39',
+ 'CMU_091_40',
+ 'CMU_091_41',
+ 'CMU_091_42',
+ 'CMU_091_43',
+ 'CMU_091_44',
+ 'CMU_091_45',
+ 'CMU_091_46',
+ 'CMU_091_47',
+ 'CMU_091_48',
+ 'CMU_091_49',
+ 'CMU_091_50',
+ 'CMU_091_51',
+ 'CMU_091_52',
+ 'CMU_091_53',
+ 'CMU_091_54',
+ 'CMU_091_55',
+ 'CMU_091_56',
+ 'CMU_091_57',
+ 'CMU_091_58',
+ 'CMU_091_59',
+ 'CMU_091_60',
+ 'CMU_091_61',
+ 'CMU_091_62',
+ 'CMU_104_53',
+ 'CMU_104_54',
+ 'CMU_104_55',
+ 'CMU_104_56',
+ 'CMU_104_57',
+ 'CMU_105_01',
+ 'CMU_105_02',
+ 'CMU_105_03',
+ 'CMU_105_04',
+ 'CMU_105_05',
+ 'CMU_105_07',
+ 'CMU_105_08',
+ 'CMU_105_10',
+ 'CMU_105_17',
+ 'CMU_105_18',
+ 'CMU_105_19',
+ 'CMU_105_20',
+ 'CMU_105_22',
+ 'CMU_105_29',
+ 'CMU_105_31',
+ 'CMU_105_34',
+ 'CMU_105_36',
+ 'CMU_105_37',
+ 'CMU_105_38',
+ 'CMU_105_39',
+ 'CMU_105_40',
+ 'CMU_105_41',
+ 'CMU_105_42',
+ 'CMU_105_43',
+ 'CMU_105_44',
+ 'CMU_105_45',
+ 'CMU_105_46',
+ 'CMU_105_47',
+ 'CMU_105_48',
+ 'CMU_105_49',
+ 'CMU_105_50',
+ 'CMU_105_51',
+ 'CMU_105_52',
+ 'CMU_105_53',
+ 'CMU_105_54',
+ 'CMU_105_55',
+ 'CMU_105_56',
+ 'CMU_105_57',
+ 'CMU_105_58',
+ 'CMU_105_59',
+ 'CMU_105_60',
+ 'CMU_105_61',
+ 'CMU_105_62',
+ 'CMU_107_01',
+ 'CMU_107_02',
+ 'CMU_107_03',
+ 'CMU_107_04',
+ 'CMU_107_05',
+ 'CMU_107_06',
+ 'CMU_107_07',
+ 'CMU_107_08',
+ 'CMU_107_09',
+ 'CMU_107_11',
+ 'CMU_107_12',
+ 'CMU_107_13',
+ 'CMU_107_14',
+ 'CMU_108_01',
+ 'CMU_108_02',
+ 'CMU_108_03',
+ 'CMU_108_04',
+ 'CMU_108_05',
+ 'CMU_108_06',
+ 'CMU_108_07',
+ 'CMU_108_08',
+ 'CMU_108_09',
+ 'CMU_108_12',
+ 'CMU_108_13',
+ 'CMU_108_14',
+ 'CMU_108_17',
+ 'CMU_108_18',
+ 'CMU_108_19',
+ 'CMU_108_20',
+ 'CMU_108_21',
+ 'CMU_108_22',
+ 'CMU_108_23',
+ 'CMU_108_24',
+ 'CMU_108_25',
+ 'CMU_108_26',
+ 'CMU_108_27',
+ 'CMU_108_28',
+ 'CMU_114_13',
+ 'CMU_114_14',
+ 'CMU_114_15',
+ 'CMU_118_01',
+ 'CMU_118_02',
+ 'CMU_118_03',
+ 'CMU_118_04',
+ 'CMU_118_05',
+ 'CMU_118_06',
+ 'CMU_118_07',
+ 'CMU_118_08',
+ 'CMU_118_09',
+ 'CMU_118_10',
+ 'CMU_118_11',
+ 'CMU_118_12',
+ 'CMU_118_13',
+ 'CMU_118_14',
+ 'CMU_118_15',
+ 'CMU_118_16',
+ 'CMU_118_17',
+ 'CMU_118_18',
+ 'CMU_118_19',
+ 'CMU_118_20',
+ 'CMU_118_21',
+ 'CMU_118_22',
+ 'CMU_118_23',
+ 'CMU_118_24',
+ 'CMU_118_25',
+ 'CMU_118_26',
+ 'CMU_118_27',
+ 'CMU_118_28',
+ 'CMU_118_29',
+ 'CMU_118_30',
+ 'CMU_118_32',
+ 'CMU_120_20',
+ 'CMU_124_03',
+ 'CMU_124_04',
+ 'CMU_124_05',
+ 'CMU_124_06',
+ 'CMU_127_02',
+ 'CMU_127_03',
+ 'CMU_127_04',
+ 'CMU_127_05',
+ 'CMU_127_06',
+ 'CMU_127_07',
+ 'CMU_127_08',
+ 'CMU_127_09',
+ 'CMU_127_10',
+ 'CMU_127_11',
+ 'CMU_127_12',
+ 'CMU_127_13',
+ 'CMU_127_14',
+ 'CMU_127_15',
+ 'CMU_127_16',
+ 'CMU_127_17',
+ 'CMU_127_18',
+ 'CMU_127_19',
+ 'CMU_127_20',
+ 'CMU_127_21',
+ 'CMU_127_22',
+ 'CMU_127_23',
+ 'CMU_127_24',
+ 'CMU_127_25',
+ 'CMU_127_26',
+ 'CMU_127_27',
+ 'CMU_127_28',
+ 'CMU_127_29',
+ 'CMU_127_30',
+ 'CMU_127_31',
+ 'CMU_127_32',
+ 'CMU_127_37',
+ 'CMU_127_38',
+ 'CMU_128_02',
+ 'CMU_128_03',
+ 'CMU_128_04',
+ 'CMU_128_05',
+ 'CMU_128_06',
+ 'CMU_128_07',
+ 'CMU_128_08',
+ 'CMU_128_09',
+ 'CMU_128_10',
+ 'CMU_128_11',
+ 'CMU_132_01',
+ 'CMU_132_02',
+ 'CMU_132_03',
+ 'CMU_132_04',
+ 'CMU_132_05',
+ 'CMU_132_06',
+ 'CMU_132_07',
+ 'CMU_132_08',
+ 'CMU_132_09',
+ 'CMU_132_10',
+ 'CMU_132_11',
+ 'CMU_132_12',
+ 'CMU_132_13',
+ 'CMU_132_14',
+ 'CMU_132_15',
+ 'CMU_132_16',
+ 'CMU_132_17',
+ 'CMU_132_18',
+ 'CMU_132_19',
+ 'CMU_132_20',
+ 'CMU_132_21',
+ 'CMU_132_22',
+ 'CMU_132_23',
+ 'CMU_132_24',
+ 'CMU_132_25',
+ 'CMU_132_26',
+ 'CMU_132_27',
+ 'CMU_132_28',
+ 'CMU_132_29',
+ 'CMU_132_30',
+ 'CMU_132_31',
+ 'CMU_132_32',
+ 'CMU_132_33',
+ 'CMU_132_34',
+ 'CMU_132_35',
+ 'CMU_132_36',
+ 'CMU_132_37',
+ 'CMU_132_38',
+ 'CMU_132_39',
+ 'CMU_132_40',
+ 'CMU_132_41',
+ 'CMU_132_42',
+ 'CMU_132_43',
+ 'CMU_132_44',
+ 'CMU_132_45',
+ 'CMU_132_46',
+ 'CMU_132_47',
+ 'CMU_132_48',
+ 'CMU_132_49',
+ 'CMU_132_50',
+ 'CMU_132_51',
+ 'CMU_132_52',
+ 'CMU_132_53',
+ 'CMU_132_54',
+ 'CMU_132_55',
+ 'CMU_133_03',
+ 'CMU_133_04',
+ 'CMU_133_05',
+ 'CMU_133_06',
+ 'CMU_133_07',
+ 'CMU_133_08',
+ 'CMU_133_10',
+ 'CMU_133_11',
+ 'CMU_133_12',
+ 'CMU_133_13',
+ 'CMU_133_14',
+ 'CMU_133_15',
+ 'CMU_133_16',
+ 'CMU_133_17',
+ 'CMU_133_18',
+ 'CMU_133_19',
+ 'CMU_133_20',
+ 'CMU_133_21',
+ 'CMU_133_22',
+ 'CMU_133_23',
+ 'CMU_133_24',
+ 'CMU_139_04',
+ 'CMU_139_10',
+ 'CMU_139_11',
+ 'CMU_139_12',
+ 'CMU_139_13',
+ 'CMU_139_14',
+ 'CMU_139_15',
+ 'CMU_139_16',
+ 'CMU_139_17',
+ 'CMU_139_18',
+ 'CMU_139_21',
+ 'CMU_139_28',
+ 'CMU_140_01',
+ 'CMU_140_02',
+ 'CMU_140_08',
+ 'CMU_140_09',
+ 'CMU_143_01',
+ 'CMU_143_02',
+ 'CMU_143_03',
+ 'CMU_143_04',
+ 'CMU_143_05',
+ 'CMU_143_06',
+ 'CMU_143_07',
+ 'CMU_143_08',
+ 'CMU_143_09',
+ 'CMU_143_14',
+ 'CMU_143_15',
+ 'CMU_143_16',
+ 'CMU_143_29',
+ 'CMU_143_32',
+ 'CMU_143_39',
+ 'CMU_143_40',
+ 'CMU_143_41',
+ 'CMU_143_42'))
+
+
+CMU_SUBSETS_DICT = dict(
+ walk_tiny=WALK_TINY,
+ run_jump_tiny=RUN_JUMP_TINY,
+ get_up=GET_UP,
+ locomotion_small=LOCOMOTION_SMALL,
+ all=ALL
+ )
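Note on the clip lists above: they are long and maintained by hand, so a small consistency check may help catch duplicated IDs or IDs in a subset that are missing from a larger collection. The following is a sketch under stated assumptions. The `ClipCollection` here is a hypothetical stand-in that models only the `ids` field (the real `types.ClipCollection` is not shown in this change), `check_subset` is an illustrative helper, and the example data are toy tuples, not the real lists.

```python
import collections

# Hypothetical stand-in for types.ClipCollection; only `ids` is modeled here.
ClipCollection = collections.namedtuple('ClipCollection', ['ids'])


def check_subset(subset, superset):
  """Returns (duplicates, missing) for `subset` relative to `superset`."""
  counts = collections.Counter(subset.ids)
  # IDs that are listed more than once within the subset itself.
  duplicates = sorted(clip_id for clip_id, n in counts.items() if n > 1)
  # IDs in the subset that the superset does not contain.
  missing = sorted(set(subset.ids) - set(superset.ids))
  return duplicates, missing


# Toy collections for illustration only.
all_clips = ClipCollection(ids=('CMU_139_16', 'CMU_139_17', 'CMU_140_01'))
get_up = ClipCollection(ids=('CMU_139_16', 'CMU_140_01'))
broken = ClipCollection(ids=('CMU_139_16', 'CMU_139_16', 'CMU_999_99'))
```

A check like this could run in a unit test over each entry of `CMU_SUBSETS_DICT` against `ALL`, assuming every smaller subset is meant to be contained in it.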
diff --git a/dm_control/locomotion/tasks/reference_pose/datasets.py b/dm_control/locomotion/tasks/reference_pose/datasets.py
new file mode 100644
index 00000000..03023f8c
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/datasets.py
@@ -0,0 +1,22 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Datasets for reference pose tasks.
+"""
+
+from dm_control.locomotion.tasks.reference_pose import cmu_subsets
+
+
+DATASETS = dict()
+DATASETS.update(cmu_subsets.CMU_SUBSETS_DICT)
diff --git a/dm_control/locomotion/tasks/reference_pose/mocap_playback.py b/dm_control/locomotion/tasks/reference_pose/mocap_playback.py
new file mode 100644
index 00000000..8ea21f0b
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/mocap_playback.py
@@ -0,0 +1,63 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Simple script to visualize motion capture data."""
+
+from absl import app
+
+from dm_control import composer
+from dm_control import viewer
+
+from dm_control.locomotion import arenas
+from dm_control.locomotion import walkers
+
+from dm_control.locomotion.mocap import cmu_mocap_data
+from dm_control.locomotion.tasks.reference_pose import tracking
+
+
+def mocap_playback_env(random_state=None):
+ """Constructs mocap playback environment."""
+
+ # Use a position-controlled CMU humanoid walker.
+ walker_type = walkers.CMUHumanoidPositionControlledV2020
+
+ # Build an empty arena.
+ arena = arenas.Floor()
+
+ # Build a task that rewards the agent for tracking motion capture reference
+ # data.
+ task = tracking.PlaybackTask(
+ walker=walker_type,
+ arena=arena,
+ ref_path=cmu_mocap_data.get_path_for_cmu(version='2020'),
+ dataset='run_jump_tiny',
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def main(unused_argv):
+ # The viewer calls the environment_loader on episode resets. However, the task
+ # cycles through one clip per episode. To avoid replaying the first clip again
+ # and again, we construct the environment outside the viewer to make it
+ # persistent across resets.
+ env = mocap_playback_env()
+ viewer.launch(environment_loader=lambda: env)
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/dm_control/locomotion/tasks/reference_pose/rewards.py b/dm_control/locomotion/tasks/reference_pose/rewards.py
new file mode 100644
index 00000000..f2b34f1f
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/rewards.py
@@ -0,0 +1,187 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define reward function options for reference pose tasks."""
+
+import collections
+
+import numpy as np
+
+RewardFnOutput = collections.namedtuple('RewardFnOutput',
+ ['reward', 'debug', 'reward_terms'])
+
+
+def bounded_quat_dist(source: np.ndarray,
+ target: np.ndarray) -> np.ndarray:
+ """Computes a quaternion distance limiting the difference to a max of pi/2.
+
+ This function supports an arbitrary number of batch dimensions, B.
+
+ Args:
+ source: a quaternion, shape (B, 4).
+ target: another quaternion, shape (B, 4).
+
+ Returns:
+ Quaternion distance, shape (B, 1).
+ """
+ source = source / np.linalg.norm(source, axis=-1, keepdims=True)
+ target = target / np.linalg.norm(target, axis=-1, keepdims=True)
+ # "Distance" in interval [-1, 1].
+ dist = 2 * np.einsum('...i,...i', source, target) ** 2 - 1
+ # Clip at 1 to avoid occasional machine epsilon leak beyond 1.
+ dist = np.minimum(1., dist)
+ # Divide by 2 and add an axis to ensure consistency with expected return
+ # shape and magnitude.
+ return 0.5 * np.arccos(dist)[..., np.newaxis]
+
+
+def sort_dict(d):
+ return collections.OrderedDict(sorted(d.items()))
+
+
+def compute_squared_differences(walker_features, reference_features,
+ exclude_keys=()):
+ """Computes squared differences of features."""
+ squared_differences = {}
+ for k in walker_features:
+ if k not in exclude_keys:
+ if 'quaternion' not in k:
+ squared_differences[k] = np.sum(
+ (walker_features[k] - reference_features[k])**2)
+ elif 'quaternions' in k:
+ quat_dists = bounded_quat_dist(
+ walker_features[k], reference_features[k])
+ squared_differences[k] = np.sum(quat_dists**2)
+ else:
+ squared_differences[k] = bounded_quat_dist(
+ walker_features[k], reference_features[k])**2
+
+ return squared_differences
+
+
+def termination_reward_fn(termination_error, termination_error_threshold,
+ **unused_kwargs):
+ """Termination reward.
+
+ This reward is intended to be used in conjunction with the termination error
+ calculated in the task. Because the episode terminates when
+ error > error_threshold, this reward will be in [0, 1].
+
+ Args:
+ termination_error: termination error computed in tracking task
+ termination_error_threshold: task termination threshold
+ unused_kwargs: unused additional keyword arguments.
+
+ Returns:
+ RewardFnOutput tuple containing reward, debug information and reward terms.
+ """
+ debug_terms = {
+ 'termination_error': termination_error,
+ 'termination_error_threshold': termination_error_threshold
+ }
+ termination_reward = 1 - termination_error / termination_error_threshold
+ return RewardFnOutput(reward=termination_reward, debug=debug_terms,
+ reward_terms=sort_dict(
+ {'termination': termination_reward}))
+
+
+def debug(reference_features, walker_features, **unused_kwargs):
+ debug_terms = compute_squared_differences(walker_features, reference_features)
+ return RewardFnOutput(reward=0.0, debug=debug_terms, reward_terms=None)
+
+
+def multi_term_pose_reward_fn(walker_features, reference_features,
+ **unused_kwargs):
+ """A reward based on com, body quaternions, joints velocities & appendages."""
+ differences = compute_squared_differences(walker_features, reference_features)
+ com = .1 * np.exp(-10 * differences['center_of_mass'])
+ joints_velocity = 1.0 * np.exp(-0.1 * differences['joints_velocity'])
+ appendages = 0.15 * np.exp(-40. * differences['appendages'])
+ body_quaternions = 0.65 * np.exp(-2 * differences['body_quaternions'])
+ terms = {
+ 'center_of_mass': com,
+ 'joints_velocity': joints_velocity,
+ 'appendages': appendages,
+ 'body_quaternions': body_quaternions
+ }
+ reward = sum(terms.values())
+ return RewardFnOutput(reward=reward, debug=terms,
+ reward_terms=sort_dict(terms))
+
+
+def comic_reward_fn(termination_error, termination_error_threshold,
+ walker_features, reference_features, **unused_kwargs):
+ """A reward that mixes the termination_reward and multi_term_pose_reward.
+
+ This reward function was used in
+ Hasenclever et al.,
+ CoMic: Complementary Task Learning & Mimicry for Reusable Skills,
+ International Conference on Machine Learning, 2020.
+ [https://proceedings.icml.cc/static/paper_files/icml/2020/5013-Paper.pdf]
+
+ Args:
+ termination_error: termination error computed in the tracking task.
+ termination_error_threshold: threshold to determine whether to terminate
+ episodes. The threshold is used to construct a reward in [0, 1]
+ based on the termination error.
+ walker_features: current features of the walker.
+ reference_features: features of the current reference pose.
+ unused_kwargs: unused additional keyword arguments.
+
+ Returns:
+ RewardFnOutput tuple containing reward, debug terms and reward terms.
+ """
+ termination_reward, debug_terms, termination_reward_terms = (
+ termination_reward_fn(termination_error, termination_error_threshold))
+ mt_reward, mt_debug_terms, mt_reward_terms = multi_term_pose_reward_fn(
+ walker_features, reference_features)
+ debug_terms.update(mt_debug_terms)
+ reward_terms = {k: 0.5 * v for k, v in termination_reward_terms.items()}
+ reward_terms.update(
+ {k: 0.5 * v for k, v in mt_reward_terms.items()})
+ return RewardFnOutput(
+ reward=0.5 * termination_reward + 0.5 * mt_reward,
+ debug=debug_terms,
+ reward_terms=sort_dict(reward_terms))
+
+
+_REWARD_FN = {
+ 'termination_reward': termination_reward_fn,
+ 'multi_term_pose_reward': multi_term_pose_reward_fn,
+ 'comic': comic_reward_fn,
+}
+
+_REWARD_CHANNELS = {
+ 'termination_reward': ('termination',),
+ 'multi_term_pose_reward':
+ ('appendages', 'body_quaternions', 'center_of_mass', 'joints_velocity'),
+ 'comic': ('appendages', 'body_quaternions', 'center_of_mass',
+ 'joints_velocity', 'termination'),
+}
+
+
+def get_reward(reward_key):
+ if reward_key not in _REWARD_FN:
+ raise ValueError('Requested reward %s, which is not a valid option.' %
+ reward_key)
+
+ return _REWARD_FN[reward_key]
+
+
+def get_reward_channels(reward_key):
+ if reward_key not in _REWARD_CHANNELS:
+ raise ValueError('Requested reward %s, which is not a valid option.' %
+ reward_key)
+
+ return _REWARD_CHANNELS[reward_key]
diff --git a/dm_control/locomotion/tasks/reference_pose/rewards_test.py b/dm_control/locomotion/tasks/reference_pose/rewards_test.py
new file mode 100644
index 00000000..17ce7575
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/rewards_test.py
@@ -0,0 +1,84 @@
+# Copyright 2021 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for dm_control.locomotion.tasks.reference_pose.rewards."""
+
+from absl.testing import absltest
+from dm_control.locomotion.tasks.reference_pose import rewards
+import numpy as np
+
+WALKER_FEATURES = {
+ 'scalar': 0.,
+ 'vector': np.ones(3),
+ 'match': 0.1,
+}
+
+REFERENCE_FEATURES = {
+ 'scalar': 1.5,
+ 'vector': np.full(3, 2),
+ 'match': 0.1,
+}
+
+QUATERNION_FEATURES = {
+ 'unmatched_quaternion': (1., 0., 0., 0.),
+ 'matched_quaternions': [(1., 0., 1., 0.), (0.707, 0.707, 0., 0.)],
+}
+
+REFERENCE_QUATERNION_FEATURES = {
+ 'unmatched_quaternion': (0., 0., 0., 1.),
+ 'matched_quaternions': [(1., 0., 1., 0.), (0.707, 0.707, 0., 0.)],
+}
+
+
+EXPECTED_DIFFERENCES = {
+ 'scalar': 2.25,
+ 'vector': 3.,
+ 'match': 0.,
+ 'unmatched_quaternion': np.sum(rewards.bounded_quat_dist(
+ QUATERNION_FEATURES['unmatched_quaternion'],
+ REFERENCE_QUATERNION_FEATURES['unmatched_quaternion']))**2,
+ 'matched_quaternions': 0.,
+}
+
+EXCLUDE_KEYS = ('scalar', 'match')
+
+
+class RewardsTest(absltest.TestCase):
+
+ def test_compute_squared_differences(self):
+ """Basic usage."""
+ differences = rewards.compute_squared_differences(
+ WALKER_FEATURES, REFERENCE_FEATURES)
+ for key, difference in differences.items():
+ self.assertEqual(difference, EXPECTED_DIFFERENCES[key])
+
+ def test_compute_squared_differences_exclude_keys(self):
+ """Test excluding some keys from squared difference computation."""
+ differences = rewards.compute_squared_differences(
+ WALKER_FEATURES, REFERENCE_FEATURES, exclude_keys=EXCLUDE_KEYS)
+ for key in EXCLUDE_KEYS:
+ self.assertNotIn(key, differences)
+
+ def test_compute_squared_differences_quaternion(self):
+ """Test that quaternions use a different distance computation."""
+
+ differences = rewards.compute_squared_differences(
+ QUATERNION_FEATURES, REFERENCE_QUATERNION_FEATURES)
+
+ for key, difference in differences.items():
+ self.assertAlmostEqual(difference, EXPECTED_DIFFERENCES[key])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/reference_pose/tracking.py b/dm_control/locomotion/tasks/reference_pose/tracking.py
new file mode 100644
index 00000000..3c28287b
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/tracking.py
@@ -0,0 +1,1007 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tasks for multi-clip mocap tracking with RL."""
+
+import abc
+import collections
+
+import typing
+from typing import Any, Callable, Mapping, Optional, Sequence, Set, Text, Union
+
+from absl import logging
+from dm_control import composer
+from dm_control.composer.observation import observable as base_observable
+from dm_control.locomotion.mocap import loader
+
+from dm_control.locomotion.tasks.reference_pose import datasets
+from dm_control.locomotion.tasks.reference_pose import types
+from dm_control.locomotion.tasks.reference_pose import utils
+from dm_control.locomotion.tasks.reference_pose import rewards
+
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import transformations as tr
+
+from dm_env import specs
+
+import numpy as np
+import tree
+
+if typing.TYPE_CHECKING:
+ from dm_control.locomotion.walkers import legacy_base
+ from dm_control import mjcf
+
+mjlib = mjbindings.mjlib
+DEFAULT_PHYSICS_TIMESTEP = 0.005
+_MAX_END_STEP = 10000
+
+
+def _strip_reference_prefix(dictionary: Mapping[Text, Any],
+ prefix: Text,
+ keep_prefixes: Optional[Set[Text]] = None):
+ """Strips a prefix from dictionary keys and removes keys without the prefix.
+
+ Strips a prefix from the keys of a dictionary and removes any key from the
+ result dictionary that doesn't start with the given prefix, unless it starts
+ with one of the prefixes in keep_prefixes.
+
+ E.g.
+ dictionary={
+ 'example_key': 1,
+ 'example_another_key': 2,
+ 'doesnt_match': 3,
+ 'keep_this': 4,
+ }, prefix='example_', keep_prefixes=['keep_']
+
+ would return
+ {
+ 'key': 1,
+ 'another_key': 2,
+ 'keep_this': 4,
+ }
+
+ Args:
+ dictionary: The dictionary whose keys will be stripped.
+ prefix: The prefix to strip.
+ keep_prefixes: Optionally specify prefixes for keys that will be unchanged
+ and retained in the result dictionary.
+
+ Returns:
+ The dictionary with the modified keys and original values (and unchanged
+ keys specified by keep_prefixes).
+ """
+ keep_prefixes = keep_prefixes or []
+ new_dictionary = dict()
+ for key in dictionary:
+ if key.startswith(prefix):
+ key_without_prefix = key[len(prefix):]
+ # Note that this will not copy the underlying array.
+ new_dictionary[key_without_prefix] = dictionary[key]
+ else:
+ for keep_prefix in keep_prefixes:
+ if key.startswith(keep_prefix):
+ new_dictionary[key] = dictionary[key]
+
+ return new_dictionary
+
+
+class ReferencePosesTask(composer.Task, metaclass=abc.ABCMeta):
+ """Abstract base class for tasks that use reference data."""
+
+ def __init__(
+ self,
+ walker: Callable[..., 'legacy_base.Walker'],
+ arena: composer.Arena,
+ ref_path: Text,
+ ref_steps: Sequence[int],
+ dataset: Union[Text, types.ClipCollection],
+ termination_error_threshold: float = 0.3,
+ prop_termination_error_threshold: float = 0.1,
+ min_steps: int = 10,
+ reward_type: Text = 'termination_reward',
+ physics_timestep: float = DEFAULT_PHYSICS_TIMESTEP,
+ always_init_at_clip_start: bool = False,
+ proto_modifier: Optional[Any] = None,
+ prop_factory: Optional[Any] = None,
+ disable_props: bool = False,
+ ghost_offset: Optional[Sequence[Union[int, float]]] = None,
+ body_error_multiplier: Union[int, float] = 1.0,
+ actuator_force_coeff: float = 0.015,
+ enabled_reference_observables: Optional[Sequence[Text]] = None,
+ ):
+ """Abstract task that uses reference data.
+
+ Args:
+ walker: Walker constructor to be used.
+ arena: Arena to be used.
+ ref_path: Path to the dataset containing reference poses.
+ ref_steps: tuple of indices of reference observations. E.g. if
+ ref_steps=(1, 2, 3) the walker/reference observation at time t will
+ contain information from t+1, t+2, t+3.
+ dataset: A ClipCollection instance or a name of a dataset that appears as
+ a key in DATASETS in datasets.py.
+ termination_error_threshold: Error threshold for episode terminations for
+ hand body position and joint error only.
+ prop_termination_error_threshold: Error threshold for episode terminations
+ for prop position.
+ min_steps: minimum number of steps within an episode. This argument
+ determines the latest allowable starting point within a given reference
+ trajectory.
+ reward_type: type of reward to use, must be a string that appears as a key
+ in the _REWARD_FN dict in rewards.py.
+ physics_timestep: Physics timestep to use for simulation.
+ always_init_at_clip_start: only initialize episodes at the start of a
+ reference trajectory.
+ proto_modifier: Optional proto modifier to modify reference trajectories,
+ e.g. adding a vertical offset.
+ prop_factory: Optional function that takes the mocap proto and returns
+ the corresponding props for the trajectory.
+ disable_props: If prop_factory is specified but disable_props is True,
+ no props will be created.
+ ghost_offset: if not None, include a ghost rendering of the walker with
+ the reference pose at the specified position offset.
+ body_error_multiplier: A multiplier that is applied to the body error term
+ when determining failure termination condition.
+ actuator_force_coeff: A coefficient for the actuator force reward channel.
+ enabled_reference_observables: Optional iterable of enabled observables.
+ If not specified, a reasonable default set will be enabled.
+ """
+ self._ref_steps = np.sort(ref_steps)
+ self._max_ref_step = self._ref_steps[-1]
+ self._termination_error_threshold = termination_error_threshold
+ self._prop_termination_error_threshold = prop_termination_error_threshold
+ self._reward_fn = rewards.get_reward(reward_type)
+ self._reward_keys = rewards.get_reward_channels(reward_type)
+ self._min_steps = min_steps
+ self._always_init_at_clip_start = always_init_at_clip_start
+ self._ghost_offset = ghost_offset
+ self._body_error_multiplier = body_error_multiplier
+ self._actuator_force_coeff = actuator_force_coeff
+ logging.info('Reward type %s', reward_type)
+
+ if isinstance(dataset, Text):
+ try:
+ dataset = datasets.DATASETS[dataset]
+ except KeyError:
+ logging.error('Dataset %s not found in datasets.py', dataset)
+ raise
+ self._load_reference_data(
+ ref_path=ref_path, proto_modifier=proto_modifier, dataset=dataset)
+
+ self._get_possible_starts()
+
+ logging.info('%d starting points found.', len(self._possible_starts))
+
+ # load a dummy trajectory
+ self._current_clip_index = 0
+ self._current_clip = self._loader.get_trajectory(
+ self._dataset.ids[0], zero_out_velocities=False)
+ # Create the environment.
+ self._arena = arena
+ self._walker = utils.add_walker(walker, self._arena)
+ self.set_timesteps(
+ physics_timestep=physics_timestep,
+ control_timestep=self._current_clip.dt)
+
+ # Identify the desired body components.
+ try:
+ walker_bodies = self._walker.mocap_tracking_bodies
+ except AttributeError:
+ logging.error('Walker must implement mocap bodies for this task.')
+ raise
+
+ walker_bodies_names = [bdy.name for bdy in walker_bodies]
+ self._body_idxs = np.array(
+ [walker_bodies_names.index(bdy) for bdy in walker_bodies_names])
+
+ self._prop_factory = prop_factory
+ if disable_props:
+ self._props = []
+ else:
+ self._props = self._current_clip.create_props(prop_factory=prop_factory)
+ for prop in self._props:
+ self._arena.add_free_entity(prop)
+
+ # Create the observables.
+ self._add_observables(enabled_reference_observables)
+
+ # initialize counters etc.
+ self._time_step = 0
+ self._current_start_time = 0.0
+ self._last_step = 0
+ self._current_clip_index = 0
+ self._reference_observations = dict()
+ self._end_mocap = False
+ self._should_truncate = False
+
+ # Set up required dummy quantities for observations
+ self._prop_prefixes = []
+
+ self._disable_props = disable_props
+ if not disable_props:
+ if len(self._props) == 1:
+ self._prop_prefixes += ['prop/']
+ else:
+ self._prop_prefixes += [f'prop_{i}/' for i in range(len(self._props))]
+ self._clip_reference_features = self._current_clip.as_dict()
+ self._strip_reference_prefix()
+
+ self._walker_joints = self._clip_reference_features['joints'][0]
+ self._walker_features = tree.map_structure(lambda x: x[0],
+ self._clip_reference_features)
+ self._walker_features_prev = tree.map_structure(
+ lambda x: x[0], self._clip_reference_features)
+
+ self._current_reference_features = dict()
+ self._reference_ego_bodies_quats = collections.defaultdict(dict)
+ # if requested add ghost body to visualize motion capture reference.
+ if self._ghost_offset is not None:
+ self._ghost = utils.add_walker(
+ walker, self._arena, name='ghost', ghost=True)
+ self._ghost.observables.disable_all()
+
+ if disable_props:
+ self._ghost_props = []
+ else:
+ self._ghost_props = self._current_clip.create_props(
+ prop_factory=self._ghost_prop_factory)
+ for prop in self._ghost_props:
+ self._arena.add_free_entity(prop)
+ prop.observables.disable_all()
+ else:
+ self._ghost_props = []
+
+ # initialize reward channels
+ self._reset_reward_channels()
+
+ def _strip_reference_prefix(self):
+ self._clip_reference_features = _strip_reference_prefix( # pytype: disable=wrong-arg-types
+ self._clip_reference_features,
+ 'walker/',
+ keep_prefixes=self._prop_prefixes)
+
+ positions = []
+ quaternions = []
+ for prefix in self._prop_prefixes:
+ position_key, quaternion_key = f'{prefix}position', f'{prefix}quaternion'
+ positions.append(self._clip_reference_features[position_key])
+ quaternions.append(self._clip_reference_features[quaternion_key])
+ del self._clip_reference_features[position_key]
+ del self._clip_reference_features[quaternion_key]
+ # positions has dimension (#props, #timesteps, 3). However, the convention
+ # for reference observations is (#timesteps, #props, 3). Therefore we
+ # transpose the dimensions by specifying the desired positions in the list
+ # for each dimension as an argument to np.transpose.
+ axes = [1, 0, 2]
+ if self._prop_prefixes:
+ self._clip_reference_features['prop_positions'] = np.transpose(
+ positions, axes=axes)
+ self._clip_reference_features['prop_quaternions'] = np.transpose(
+ quaternions, axes=axes)
+
+ def _ghost_prop_factory(self, prop_proto, priority_friction=False):
+ if self._prop_factory is None:
+ return None
+
+ prop = self._prop_factory(prop_proto, priority_friction=priority_friction)
+ for geom in prop.mjcf_model.find_all('geom'):
+ geom.set_attributes(contype=0, conaffinity=0, rgba=(0.5, 0.5, 0.5, .999))
+ prop.observables.disable_all()
+ return prop
+
+ def _load_reference_data(self, ref_path, proto_modifier,
+ dataset: types.ClipCollection):
+ self._loader = loader.HDF5TrajectoryLoader(
+ ref_path, proto_modifier=proto_modifier)
+
+ self._dataset = dataset
+ self._num_clips = len(self._dataset.ids)
+
+ if self._dataset.end_steps is None:
+ # load all trajectories to infer clip end steps.
+ self._all_clips = [
+ self._loader.get_trajectory( # pylint: disable=g-complex-comprehension
+ clip_id,
+ start_step=clip_start_step,
+ end_step=_MAX_END_STEP) for clip_id, clip_start_step in zip(
+ self._dataset.ids, self._dataset.start_steps)
+ ]
+ # infer clip end steps to set sampling distribution
+ self._dataset.end_steps = tuple(clip.end_step for clip in self._all_clips)
+ else:
+ self._all_clips = [None] * self._num_clips
+
+ def _add_observables(self, enabled_reference_observables):
+
+ # pylint: disable=g-long-lambda
+ self._walker.observables.add_observable(
+ 'reference_rel_joints',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_joints']))
+ self._walker.observables.add_observable(
+ 'reference_rel_bodies_pos_global',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_bodies_pos_global']))
+ self._walker.observables.add_observable(
+ 'reference_rel_bodies_quats',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_bodies_quats']))
+ self._walker.observables.add_observable(
+ 'reference_rel_bodies_pos_local',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_bodies_pos_local']))
+ self._walker.observables.add_observable(
+ 'reference_ego_bodies_quats',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_ego_bodies_quats']))
+ self._walker.observables.add_observable(
+ 'reference_rel_root_quat',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_root_quat']))
+ self._walker.observables.add_observable(
+ 'reference_rel_root_pos_local',
+ base_observable.Generic(lambda _: self._reference_observations[
+ 'walker/reference_rel_root_pos_local']))
+ # pylint: enable=g-long-lambda
+ self._walker.observables.add_observable(
+ 'reference_appendages_pos',
+ base_observable.Generic(self.get_reference_appendages_pos))
+
+ if enabled_reference_observables:
+ for name, observable in self.observables.items():
+ observable.enabled = name in enabled_reference_observables
+ self._walker.observables.add_observable(
+ 'clip_id', base_observable.Generic(self.get_clip_id))
+ self._walker.observables.add_observable(
+ 'velocimeter_control', base_observable.Generic(self.get_veloc_control))
+ self._walker.observables.add_observable(
+ 'gyro_control', base_observable.Generic(self.get_gyro_control))
+ self._walker.observables.add_observable(
+ 'joints_vel_control',
+ base_observable.Generic(self.get_joints_vel_control))
+
+ self._arena.observables.add_observable(
+ 'reference_props_pos_global',
+ base_observable.Generic(self.get_reference_props_pos_global))
+ self._arena.observables.add_observable(
+ 'reference_props_quat_global',
+ base_observable.Generic(self.get_reference_props_quat_global))
+ observables = []
+ observables += self._walker.observables.proprioception
+ observables += self._walker.observables.kinematic_sensors
+ observables += self._walker.observables.dynamic_sensors
+
+ for observable in observables:
+ observable.enabled = True
+
+ for prop in self._props:
+ prop.observables.position.enabled = True
+ prop.observables.orientation.enabled = True
+
+ def _get_possible_starts(self):
+ # List all possible (clip, step) starting points.
+ self._possible_starts = []
+ self._start_probabilities = []
+ dataset = self._dataset
+ for clip_number, (start, end, weight) in enumerate(
+ zip(dataset.start_steps, dataset.end_steps, dataset.weights)):
+ # length - required lookahead - minimum number of steps
+ last_possible_start = end - self._max_ref_step - self._min_steps
+
+ if self._always_init_at_clip_start:
+ self._possible_starts += [(clip_number, start)]
+ self._start_probabilities += [weight]
+ else:
+ self._possible_starts += [
+ (clip_number, j) for j in range(start, last_possible_start)
+ ]
+ self._start_probabilities += [
+ weight for _ in range(start, last_possible_start)
+ ]
+
+ # normalize start probabilities
+ self._start_probabilities = np.array(self._start_probabilities) / np.sum(
+ self._start_probabilities)
+
+ def initialize_episode_mjcf(self, random_state: np.random.RandomState):
+ if hasattr(self._arena, 'regenerate'):
+ self._arena.regenerate(random_state)
+
+ # Get a new clip here to instantiate the right prop for this episode.
+ self._get_clip_to_track(random_state)
+ # Set up props.
+ # We call the prop factory here to ensure that props can change per episode.
+ for prop in self._props:
+ prop.detach()
+ del prop
+
+ if not self._disable_props:
+ self._props = self._current_clip.create_props(
+ prop_factory=self._prop_factory)
+ for prop in self._props:
+ self._arena.add_free_entity(prop)
+ prop.observables.position.enabled = True
+ prop.observables.orientation.enabled = True
+
+ if self._ghost_offset is not None:
+ for prop in self._ghost_props:
+ prop.detach()
+ del prop
+ self._ghost_props = self._current_clip.create_props(
+ prop_factory=self._ghost_prop_factory)
+ for prop in self._ghost_props:
+ self._arena.add_free_entity(prop)
+ prop.observables.disable_all()
+
+ def _get_clip_to_track(self, random_state: np.random.RandomState):
+ # Randomly select a starting point.
+ index = random_state.choice(
+ len(self._possible_starts), p=self._start_probabilities)
+ clip_index, start_step = self._possible_starts[index]
+
+ self._current_clip_index = clip_index
+ clip_id = self._dataset.ids[self._current_clip_index]
+
+ if self._all_clips[self._current_clip_index] is None:
+ # fetch selected trajectory
+ logging.info('Loading clip %s', clip_id)
+ self._all_clips[self._current_clip_index] = self._loader.get_trajectory(
+ clip_id,
+ start_step=self._dataset.start_steps[self._current_clip_index],
+ end_step=self._dataset.end_steps[self._current_clip_index],
+ zero_out_velocities=False)
+ self._current_clip = self._all_clips[self._current_clip_index]
+ self._clip_reference_features = self._current_clip.as_dict()
+ self._strip_reference_prefix()
+
+ # The reference features are already restricted to
+ # clip_start_step:clip_end_step. However, start_step is in
+ # [clip_start_step:clip_end_step]. Hence we subtract clip_start_step to
+ # obtain a valid index for the reference features.
+ self._time_step = start_step - self._dataset.start_steps[
+ self._current_clip_index]
+ self._current_start_time = (start_step - self._dataset.start_steps[
+ self._current_clip_index]) * self._current_clip.dt
+ self._last_step = len(
+ self._clip_reference_features['joints']) - self._max_ref_step - 1
+ logging.info('Mocap %s at step %d with remaining length %d.', clip_id,
+ start_step, self._last_step - self._time_step)
+
+ def initialize_episode(self, physics: 'mjcf.Physics',
+ random_state: np.random.RandomState):
+ """Randomly selects a starting point and sets the walker."""
+
+ # Set the walker at the beginning of the clip.
+ self._set_walker(physics)
+ self._walker_features = utils.get_features(
+ physics, self._walker, props=self._props)
+ self._walker_features_prev = self._walker_features.copy()
+
+ self._walker_joints = np.array(physics.bind(self._walker.mocap_joints).qpos) # pytype: disable=attribute-error
+
+ # compute initial error
+ self._compute_termination_error()
+ # Check that the error is close to 0 at initialization. In particular this
+ # will catch a proto/walker mismatch.
+ if self._termination_error > 1e-2:
+ raise ValueError(('The termination error exceeds 1e-2 at initialization. '
+ 'This is likely due to a proto/walker mismatch.'))
+
+ self._update_ghost(physics)
+ self._reference_observations.update(
+ self.get_all_reference_observations(physics))
+
+ # reset reward channels
+ self._reset_reward_channels()
+
+ def _reset_reward_channels(self):
+ if self._reward_keys:
+ self.last_reward_channels = collections.OrderedDict([
+ (k, 0.0) for k in self._reward_keys
+ ])
+ else:
+ self.last_reward_channels = None
+
+ def _compute_termination_error(self):
+ target_joints = self._clip_reference_features['joints'][self._time_step]
+ error_joints = np.mean(np.abs(target_joints - self._walker_joints))
+ target_bodies = self._clip_reference_features['body_positions'][
+ self._time_step]
+ error_bodies = np.mean(
+ np.abs((target_bodies -
+ self._walker_features['body_positions'])[self._body_idxs]))
+ self._termination_error = (
+ 0.5 * self._body_error_multiplier * error_bodies + 0.5 * error_joints)
+
+ if self._props:
+ target_props = self._clip_reference_features['prop_positions'][
+ self._time_step]
+ cur_props = self._walker_features['prop_positions']
+ # Separately compute prop termination error as euclidean distance.
+ self._prop_termination_error = np.mean(
+ np.linalg.norm(target_props - cur_props, axis=-1))
+
+ def before_step(self, physics: 'mjcf.Physics', action,
+ random_state: np.random.RandomState):
+ self._walker.apply_action(physics, action, random_state)
+
+ def after_step(self, physics: 'mjcf.Physics',
+ random_state: np.random.RandomState):
+ """Update the data after step."""
+ del random_state # unused by after_step.
+
+ self._walker_features_prev = self._walker_features.copy()
+
+ def after_compile(self, physics: 'mjcf.Physics',
+ random_state: np.random.RandomState):
+ # populate reference observations field to initialize observations.
+ if not self._reference_observations:
+ self._reference_observations.update(
+ self.get_all_reference_observations(physics))
+
+ def should_terminate_episode(self, physics: 'mjcf.Physics'):
+ del physics # physics unused by should_terminate_episode.
+
+ if self._should_truncate:
+ logging.info('Truncate with error %f.', self._termination_error)
+ return True
+
+ if self._end_mocap:
+ logging.info('End of mocap.')
+ return True
+
+ return False
+
+ def get_discount(self, physics: 'mjcf.Physics'):
+ del physics # unused by get_discount.
+
+ if self._should_truncate:
+ return 0.0
+ return 1.0
+
+ def get_reference_rel_joints(self, physics: 'mjcf.Physics'):
+ """Observation of the reference joints relative to walker."""
+ del physics # physics unused by reference observations.
+ time_steps = self._time_step + self._ref_steps
+ diff = (self._clip_reference_features['joints'][time_steps] -
+ self._walker_joints)
+ return diff[:, self._walker.mocap_to_observable_joint_order].flatten()
+
+ def get_reference_rel_bodies_pos_global(self, physics: 'mjcf.Physics'):
+ """Observation of the reference bodies relative to walker."""
+ del physics # physics unused by reference observations.
+
+ time_steps = self._time_step + self._ref_steps
+ return (self._clip_reference_features['body_positions'][time_steps] -
+ self._walker_features['body_positions'])[:,
+ self._body_idxs].flatten()
+
+ def get_reference_rel_bodies_quats(self, physics: 'mjcf.Physics'):
+ """Observation of the reference bodies quats relative to walker."""
+ del physics # physics unused by reference observations.
+
+ time_steps = self._time_step + self._ref_steps
+ obs = []
+ for t in time_steps:
+ for b in self._body_idxs:
+ obs.append(
+ tr.quat_diff(
+ self._walker_features['body_quaternions'][b, :],
+ self._clip_reference_features['body_quaternions'][t, b, :]))
+ return np.concatenate([o.flatten() for o in obs])
+
+ def get_reference_rel_bodies_pos_local(self, physics: 'mjcf.Physics'):
+ """Observation of the reference bodies relative to walker in local frame."""
+ time_steps = self._time_step + self._ref_steps
+ obs = self._walker.transform_vec_to_egocentric_frame(
+ physics, (self._clip_reference_features['body_positions'][time_steps] -
+ self._walker_features['body_positions'])[:, self._body_idxs])
+ return np.concatenate([o.flatten() for o in obs])
+
+ def get_reference_ego_bodies_quats(self, unused_physics: 'mjcf.Physics'):
+ """Body quat of the reference relative to the reference root quat."""
+ time_steps = self._time_step + self._ref_steps
+ obs = []
+ quats_for_clip = self._reference_ego_bodies_quats[self._current_clip_index]
+ for t in time_steps:
+ if t not in quats_for_clip:
+ root_quat = self._clip_reference_features['quaternion'][t, :]
+ quats_for_clip[t] = [
+ tr.quat_diff( # pylint: disable=g-complex-comprehension
+ root_quat,
+ self._clip_reference_features['body_quaternions'][t, b, :])
+ for b in self._body_idxs
+ ]
+ obs.extend(quats_for_clip[t])
+ return np.concatenate([o.flatten() for o in obs])
+
+ def get_reference_rel_root_quat(self, physics: 'mjcf.Physics'):
+ """Root quaternion of reference relative to current root quat."""
+ del physics # physics unused by reference observations.
+
+ time_steps = self._time_step + self._ref_steps
+ obs = []
+ for t in time_steps:
+ obs.append(
+ tr.quat_diff(self._walker_features['quaternion'],
+ self._clip_reference_features['quaternion'][t, :]))
+ return np.concatenate([o.flatten() for o in obs])
+
+ def get_reference_appendages_pos(self, physics: 'mjcf.Physics'):
+ """Reference appendage positions in reference frame."""
+ del physics # physics unused by reference observations.
+
+ time_steps = self._time_step + self._ref_steps
+ return self._clip_reference_features['appendages'][time_steps].flatten()
+
+ def get_reference_rel_root_pos_local(self, physics: 'mjcf.Physics'):
+ """Reference position relative to current root position in root frame."""
+ time_steps = self._time_step + self._ref_steps
+ obs = self._walker.transform_vec_to_egocentric_frame(
+ physics, (self._clip_reference_features['position'][time_steps] -
+ self._walker_features['position']))
+ return np.concatenate([o.flatten() for o in obs])
+
+ def get_reference_props_pos_global(self, physics: 'mjcf.Physics'):
+ time_steps = self._time_step + self._ref_steps
+ # size N x 3 where N = # of props
+ if self._props:
+ return self._clip_reference_features['prop_positions'][
+ time_steps].flatten()
+ else:
+ return []
+
+ def get_reference_props_quat_global(self, physics: 'mjcf.Physics'):
+ time_steps = self._time_step + self._ref_steps
+ # size N x 4 where N = # of props
+ if self._props:
+ return self._clip_reference_features['prop_quaternions'][
+ time_steps].flatten()
+ else:
+ return []
+
+ def get_veloc_control(self, physics: 'mjcf.Physics'):
+ """Velocity measurements in the prev root frame at the control timestep."""
+ del physics # physics unused by get_veloc_control.
+
+ rmat_prev = tr.quat_to_mat(self._walker_features_prev['quaternion'])[:3, :3]
+ veloc_world = (
+ self._walker_features['position'] -
+ self._walker_features_prev['position']) / self._control_timestep
+ return np.dot(veloc_world, rmat_prev)
+
+ def get_gyro_control(self, physics: 'mjcf.Physics'):
+ """Gyro measurements in the prev root frame at the control timestep."""
+ del physics # physics unused by get_gyro_control.
+
+ quat_curr, quat_prev = (self._walker_features['quaternion'],
+ self._walker_features_prev['quaternion'])
+ normed_diff = tr.quat_diff(quat_prev, quat_curr)
+ normed_diff /= np.linalg.norm(normed_diff)
+ return tr.quat_to_axisangle(normed_diff) / self._control_timestep
+
+ def get_joints_vel_control(self, physics: 'mjcf.Physics'):
+ """Joint velocity measurements at the control timestep."""
+ del physics # physics unused by get_joints_vel_control.
+
+ joints_curr, joints_prev = (self._walker_features['joints'],
+ self._walker_features_prev['joints'])
+ return (joints_curr - joints_prev)[
+ self._walker.mocap_to_observable_joint_order]/self._control_timestep
+
+ def get_clip_id(self, physics: 'mjcf.Physics'):
+ """Observation of the clip id."""
+ del physics # physics unused by get_clip_id.
+
+ return np.array([self._current_clip_index])
+
+ def get_all_reference_observations(self, physics: 'mjcf.Physics'):
+ reference_observations = dict()
+ reference_observations[
+ 'walker/reference_rel_bodies_pos_local'] = self.get_reference_rel_bodies_pos_local(
+ physics)
+ reference_observations[
+ 'walker/reference_rel_joints'] = self.get_reference_rel_joints(physics)
+ reference_observations[
+ 'walker/reference_rel_bodies_pos_global'] = self.get_reference_rel_bodies_pos_global(
+ physics)
+ reference_observations[
+ 'walker/reference_ego_bodies_quats'] = self.get_reference_ego_bodies_quats(
+ physics)
+ reference_observations[
+ 'walker/reference_rel_root_quat'] = self.get_reference_rel_root_quat(
+ physics)
+ reference_observations[
+ 'walker/reference_rel_bodies_quats'] = self.get_reference_rel_bodies_quats(
+ physics)
+ reference_observations[
+ 'walker/reference_rel_root_pos_local'] = self.get_reference_rel_root_pos_local(
+ physics)
+ if self._props:
+ reference_observations[
+ 'props/reference_pos_global'] = self.get_reference_props_pos_global(
+ physics)
+ reference_observations[
+ 'props/reference_quat_global'] = self.get_reference_props_quat_global(
+ physics)
+ return reference_observations
+
+ def get_reward(self, physics: 'mjcf.Physics') -> float:
+ reward, unused_debug_outputs, reward_channels = self._reward_fn(
+ termination_error=self._termination_error,
+ termination_error_threshold=self._termination_error_threshold,
+ reference_features=self._current_reference_features,
+ walker_features=self._walker_features,
+ reference_observations=self._reference_observations)
+
+ if 'actuator_force' in self._reward_keys:
+ reward_channels['actuator_force'] = -self._actuator_force_coeff*np.mean(
+ np.square(self._walker.actuator_force(physics)))
+
+ self._should_truncate = self._termination_error > self._termination_error_threshold
+
+ if self._props:
+ prop_termination = self._prop_termination_error > self._prop_termination_error_threshold
+ self._should_truncate = self._should_truncate or prop_termination
+
+ self.last_reward_channels = reward_channels
+ return reward
+
+ def _set_walker(self, physics: 'mjcf.Physics'):
+ timestep_features = tree.map_structure(lambda x: x[self._time_step],
+ self._clip_reference_features)
+ utils.set_walker_from_features(physics, self._walker, timestep_features)
+ if self._props:
+ utils.set_props_from_features(physics, self._props, timestep_features)
+ mjlib.mj_kinematics(physics.model.ptr, physics.data.ptr)
+
+ def _update_ghost(self, physics: 'mjcf.Physics'):
+ if self._ghost_offset is not None:
+ target = tree.map_structure(lambda x: x[self._time_step],
+ self._clip_reference_features)
+ utils.set_walker_from_features(physics, self._ghost, target,
+ self._ghost_offset)
+ if self._ghost_props:
+ utils.set_props_from_features(
+ physics, self._ghost_props, target, z_offset=self._ghost_offset)
+ mjlib.mj_kinematics(physics.model.ptr, physics.data.ptr)
+
+ def action_spec(self, physics: 'mjcf.Physics'):
+ """Action spec of the walker only."""
+ ctrl = physics.bind(self._walker.actuators).ctrl # pytype: disable=attribute-error
+ shape = ctrl.shape
+ dtype = ctrl.dtype
+ minimum = []
+ maximum = []
+ for actuator in self._walker.actuators:
+ if physics.bind(actuator).ctrllimited: # pytype: disable=attribute-error
+ ctrlrange = physics.bind(actuator).ctrlrange # pytype: disable=attribute-error
+ minimum.append(ctrlrange[0])
+ maximum.append(ctrlrange[1])
+ else:
+ minimum.append(-float('inf'))
+ maximum.append(float('inf'))
+ return specs.BoundedArray(
+ shape=shape,
+ dtype=dtype,
+ minimum=np.asarray(minimum, dtype=dtype),
+ maximum=np.asarray(maximum, dtype=dtype),
+ name='\t'.join(actuator.full_identifier # pytype: disable=attribute-error
+ for actuator in self._walker.actuators))
+
+ @property
+ @abc.abstractmethod
+ def name(self):
+ raise NotImplementedError
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+
+class MultiClipMocapTracking(ReferencePosesTask):
+ """Task for multi-clip mocap tracking."""
+
+ def __init__(
+ self,
+ walker: Callable[..., 'legacy_base.Walker'],
+ arena: composer.Arena,
+ ref_path: Text,
+ ref_steps: Sequence[int],
+ dataset: Union[Text, Sequence[Any]],
+ termination_error_threshold: float = 0.3,
+ prop_termination_error_threshold: float = 0.1,
+ min_steps: int = 10,
+ reward_type: Text = 'termination_reward',
+ physics_timestep: float = DEFAULT_PHYSICS_TIMESTEP,
+ always_init_at_clip_start: bool = False,
+ proto_modifier: Optional[Any] = None,
+ prop_factory: Optional[Any] = None,
+ disable_props: bool = True,
+ ghost_offset: Optional[Sequence[Union[int, float]]] = None,
+ body_error_multiplier: Union[int, float] = 1.0,
+ actuator_force_coeff: float = 0.015,
+ enabled_reference_observables: Optional[Sequence[Text]] = None,
+ ):
+ """Mocap tracking task.
+
+ Args:
+ walker: Walker constructor to be used.
+ arena: Arena to be used.
+ ref_path: Path to the dataset containing reference poses.
+ ref_steps: tuple of indices of reference observations. E.g. if
+ ref_steps=(1, 2, 3) the walker/reference observation at time t will
+ contain information from t+1, t+2, t+3.
+ dataset: A ClipCollection instance or a named dataset that
+ appears as a key in DATASETS in datasets.py.
+ termination_error_threshold: Error threshold for episode terminations for
+ hand body position and joint error only.
+ prop_termination_error_threshold: Error threshold for episode terminations
+ for prop position.
+ min_steps: minimum number of steps within an episode. This argument
+ determines the latest allowable starting point within a given reference
+ trajectory.
+ reward_type: type of reward to use, must be a string that appears as a key
+ in the REWARD_FN dict in rewards.py.
+ physics_timestep: Physics timestep to use for simulation.
+ always_init_at_clip_start: only initialize episodes at the start of a
+ reference trajectory.
+ proto_modifier: Optional proto modifier to modify reference trajectories,
+ e.g. adding a vertical offset.
+ prop_factory: Optional function that takes the mocap proto and returns
+ the corresponding props for the trajectory.
+ disable_props: If prop_factory is specified but disable_props is True,
+ no props will be created.
+ ghost_offset: if not None, include a ghost rendering of the walker with
+ the reference pose at the specified position offset.
+ body_error_multiplier: A multiplier that is applied to the body error term
+ when determining failure termination condition.
+ actuator_force_coeff: A coefficient for the actuator force reward channel.
+ enabled_reference_observables: Optional iterable of enabled observables.
+ If not specified, a reasonable default set will be enabled.
+ """
+ super().__init__(
+ walker=walker,
+ arena=arena,
+ ref_path=ref_path,
+ ref_steps=ref_steps,
+ termination_error_threshold=termination_error_threshold,
+ prop_termination_error_threshold=prop_termination_error_threshold,
+ min_steps=min_steps,
+ dataset=dataset,
+ reward_type=reward_type,
+ physics_timestep=physics_timestep,
+ always_init_at_clip_start=always_init_at_clip_start,
+ proto_modifier=proto_modifier,
+ prop_factory=prop_factory,
+ disable_props=disable_props,
+ ghost_offset=ghost_offset,
+ body_error_multiplier=body_error_multiplier,
+ actuator_force_coeff=actuator_force_coeff,
+ enabled_reference_observables=enabled_reference_observables)
+ self._walker.observables.add_observable(
+ 'time_in_clip',
+ base_observable.Generic(self.get_normalized_time_in_clip))
+
+ def after_step(self, physics: 'mjcf.Physics', random_state):
+ """Update the data after step."""
+ super().after_step(physics, random_state)
+ self._time_step += 1
+
+ # Update the walker's data for this timestep.
+ self._walker_features = utils.get_features(
+ physics, self._walker, props=self._props)
+ # features for default error
+ self._walker_joints = np.array(physics.bind(self._walker.mocap_joints).qpos) # pytype: disable=attribute-error
+
+ self._current_reference_features = {
+ k: v[self._time_step].copy()
+ for k, v in self._clip_reference_features.items()
+ }
+
+ # Error.
+ self._compute_termination_error()
+
+ # Flag the end of the mocap clip once the last valid step is reached.
+ self._end_mocap = self._time_step == self._last_step
+
+ self._reference_observations.update(
+ self.get_all_reference_observations(physics))
+
+ self._update_ghost(physics)
+
+ def get_normalized_time_in_clip(self, physics: 'mjcf.Physics'):
+ """Observation of the normalized time in the mocap clip."""
+ normalized_time_in_clip = (self._current_start_time +
+ physics.time()) / self._current_clip.duration
+ return np.array([normalized_time_in_clip])
+
+ @property
+ def name(self):
+ return 'MultiClipMocapTracking'
+
+
+class PlaybackTask(ReferencePosesTask):
+ """Simple task to visualize mocap data."""
+
+ def __init__(self,
+ walker,
+ arena,
+ ref_path: Text,
+ dataset: Union[Text, types.ClipCollection],
+ proto_modifier: Optional[Any] = None,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP):
+ super().__init__(walker=walker,
+ arena=arena,
+ ref_path=ref_path,
+ ref_steps=(1,),
+ dataset=dataset,
+ termination_error_threshold=np.inf,
+ physics_timestep=physics_timestep,
+ always_init_at_clip_start=True,
+ proto_modifier=proto_modifier)
+ self._current_clip_index = -1
+
+ def _get_clip_to_track(self, random_state: np.random.RandomState):
+ self._current_clip_index = (self._current_clip_index + 1) % self._num_clips
+
+ start_step = self._dataset.start_steps[self._current_clip_index]
+ clip_id = self._dataset.ids[self._current_clip_index]
+ logging.info('Showing clip %d of %d, clip id %s',
+ self._current_clip_index+1, self._num_clips, clip_id)
+
+ if self._all_clips[self._current_clip_index] is None:
+ # fetch selected trajectory
+ logging.info('Loading clip %s', clip_id)
+ self._all_clips[self._current_clip_index] = self._loader.get_trajectory(
+ clip_id,
+ start_step=self._dataset.start_steps[self._current_clip_index],
+ end_step=self._dataset.end_steps[self._current_clip_index],
+ zero_out_velocities=False)
+ self._current_clip = self._all_clips[self._current_clip_index]
+ self._clip_reference_features = self._current_clip.as_dict()
+ self._clip_reference_features = _strip_reference_prefix(
+ self._clip_reference_features, 'walker/')
+ # The reference features are already restricted to
+ # clip_start_step:clip_end_step. However start_step is in
+ # [clip_start_step:clip_end_step]. Hence we subtract clip_start_step to
+ # obtain a valid index for the reference features.
+ self._time_step = start_step - self._dataset.start_steps[
+ self._current_clip_index]
+ self._current_start_time = (start_step - self._dataset.start_steps[
+ self._current_clip_index]) * self._current_clip.dt
+ self._last_step = len(
+ self._clip_reference_features['joints']) - self._max_ref_step - 1
+ logging.info('Mocap %s at step %d with remaining length %d.', clip_id,
+ start_step, self._last_step - start_step)
+
+ def _set_walker(self, physics: 'mjcf.Physics'):
+ timestep_features = tree.map_structure(lambda x: x[self._time_step],
+ self._clip_reference_features)
+ utils.set_walker_from_features(physics, self._walker, timestep_features)
+ mjlib.mj_kinematics(physics.model.ptr, physics.data.ptr)
+
+ def after_step(self, physics, random_state: np.random.RandomState):
+ super().after_step(physics, random_state)
+ self._time_step += 1
+
+ self._set_walker(physics)
+ self._end_mocap = self._time_step == self._last_step
+
+ def get_reward(self, physics):
+ return 0.0
+
+ @property
+ def name(self):
+ return 'PlaybackTask'
diff --git a/dm_control/locomotion/tasks/reference_pose/tracking_test.py b/dm_control/locomotion/tasks/reference_pose/tracking_test.py
new file mode 100644
index 00000000..83a08828
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/tracking_test.py
@@ -0,0 +1,321 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for mocap tracking."""
+
+import os
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control import composer
+from dm_control.locomotion import arenas
+from dm_control.locomotion import walkers
+from dm_control.locomotion.mocap import props
+from dm_control.locomotion.tasks.reference_pose import tracking
+from dm_control.locomotion.tasks.reference_pose import types
+
+import numpy as np
+
+from dm_control.utils import io as resources
+
+TEST_FILE_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), '../../mocap'))
+TEST_FILE_PATH = os.path.join(TEST_FILE_DIR, 'test_trajectories.h5')
+
+REFERENCE_PROP_KEYS = [
+ f'reference_props_{key}_global' for key in ['pos', 'quat']
+]
+PROP_OBSERVATION_KEYS = [
+ f'cmuv2019_box/{key}' for key in ['position', 'orientation']
+]
+N_PROPS = 1
+GHOST_OFFSET = np.array((0, 0, 0.1))
+
+
+class MultiClipMocapTrackingTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self.walker = walkers.CMUHumanoidPositionControlled
+ def _make_wrong_walker(name):
+ return walkers.CMUHumanoidPositionControlled(
+ include_face=False, model_version='2020', scale_default=True,
+ name=name)
+ self.wrong_walker = _make_wrong_walker
+ self.arena = arenas.Floor()
+ self.test_data = resources.GetResourceFilename(TEST_FILE_PATH)
+
+ @parameterized.named_parameters(('termination_reward', 'termination_reward'),
+ ('comic', 'comic'))
+ def test_initialization_and_step(self, reward):
+
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ reward_type=reward,
+ )
+
+ env = composer.Environment(task=task)
+
+ env.reset()
+
+ # check no task error after episode init before first step
+ self.assertLess(task._termination_error, 1e-3)
+
+ action_spec = env.action_spec()
+ env.step(np.zeros(action_spec.shape))
+
+ @parameterized.named_parameters(('first_clip', 0), ('second_clip', 1))
+ def test_clip_weights(self, clip_number):
+ # test whether clip weights select the expected clip.
+
+ clip_weights = (1, 0) if clip_number == 0 else (0, 1)
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ dataset=types.ClipCollection(
+ ids=('cmuv2019_001', 'cmuv2019_002'), weights=clip_weights),
+ reward_type='comic',
+ )
+
+ env = composer.Environment(task=task)
+
+ env.reset()
+
+ self.assertEqual(task._current_clip.identifier,
+ task._dataset.ids[clip_number])
+
+ @parameterized.named_parameters(
+ ('start_step_id_length_mismatch_explicit_id', (0,), (10, 10), (1, 1)),
+ ('end_step_id_length_mismatch_explicit_id', (0, 0), (10,), (1, 1)),
+ ('clip_weights_id_length_mismatch_explicit_id', (0, 0), (10, 10), (1,)),
+ )
+ def test_task_validation(self, clip_start_steps, clip_end_steps,
+ clip_weights):
+ # test whether task construction fails with invalid arguments.
+ with self.assertRaisesRegex(ValueError, 'ClipCollection'):
+ unused_task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ dataset=types.ClipCollection(
+ ids=('cmuv2019_001', 'cmuv2019_002'),
+ start_steps=clip_start_steps,
+ end_steps=clip_end_steps,
+ weights=clip_weights),
+ reward_type='comic',
+ )
+
+ def test_init_at_clip_start(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(
+ ids=('cmuv2019_001', 'cmuv2019_002'),
+ start_steps=(2, 0),
+ end_steps=(10, 10)),
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ reward_type='termination_reward',
+ always_init_at_clip_start=True,
+ )
+ self.assertEqual(task._possible_starts, [(0, 2), (1, 0)])
+
+ def test_failure_with_wrong_walker(self):
+ with self.assertRaisesRegex(ValueError, 'proto/walker'):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.wrong_walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ dataset=types.ClipCollection(
+ ids=('cmuv2019_001', 'cmuv2019_002'),
+ start_steps=(0, 0),
+ end_steps=(10, 10)),
+ reward_type='comic',
+ )
+
+ env = composer.Environment(task=task)
+
+ env.reset()
+
+ def test_enabled_reference_observables(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(1, 2, 3, 4, 5),
+ min_steps=1,
+ reward_type='comic',
+ enabled_reference_observables=('walker/reference_rel_joints',)
+ )
+
+ env = composer.Environment(task=task)
+
+ timestep = env.reset()
+
+ self.assertIn('walker/reference_rel_joints', timestep.observation.keys())
+ self.assertNotIn('walker/reference_rel_root_pos_local',
+ timestep.observation.keys())
+
+ # check that all desired observables are enabled.
+ desired_observables = []
+ desired_observables += task._walker.observables.proprioception
+ desired_observables += task._walker.observables.kinematic_sensors
+ desired_observables += task._walker.observables.dynamic_sensors
+
+ for observable in desired_observables:
+ self.assertTrue(observable.enabled)
+
+ def test_prop_factory(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ disable_props=False,
+ prop_factory=props.Prop,
+ )
+ env = composer.Environment(task=task)
+
+ observation = env.reset().observation
+ # Test the expected prop observations exist and have the expected size.
+ dims = [3, 4]
+ for key, dim in zip(REFERENCE_PROP_KEYS, dims):
+ self.assertIn(key, task.observables)
+ self.assertSequenceEqual(observation[key].shape, (N_PROPS, dim))
+
+ # Since no ghost offset was specified, test that there are no ghost props.
+ self.assertEmpty(task._ghost_props)
+
+ # Test that props go to the expected location on reset.
+ for ref_key, obs_key in zip(REFERENCE_PROP_KEYS, PROP_OBSERVATION_KEYS):
+ np.testing.assert_array_almost_equal(
+ observation[ref_key], observation[obs_key]
+ )
+
+ def test_ghost_prop(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ disable_props=False,
+ prop_factory=props.Prop,
+ ghost_offset=GHOST_OFFSET,
+ )
+ env = composer.Environment(task=task)
+
+ # Test that the ghost props are present when ghost_offset specified.
+ self.assertLen(task._ghost_props, N_PROPS)
+
+ # Test that the ghost prop tracks the goal trajectory after step.
+ env.reset()
+ observation = env.step(env.action_spec().generate_value()).observation
+ ghost_pos, ghost_quat = task._ghost_props[0].get_pose(env.physics)
+ goal_pos, goal_quat = (
+ np.squeeze(observation[key]) for key in REFERENCE_PROP_KEYS)
+
+ np.testing.assert_array_equal(np.array(ghost_pos), goal_pos + GHOST_OFFSET)
+ np.testing.assert_array_almost_equal(ghost_quat, goal_quat)
+
+ def test_disable_props(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ prop_factory=props.Prop,
+ disable_props=True,
+ )
+ env = composer.Environment(task=task)
+
+ observation = env.reset().observation
+ # Test that the prop observations are empty.
+ for key in REFERENCE_PROP_KEYS:
+ self.assertIn(key, task.observables)
+ self.assertSequenceEqual(observation[key].shape, (1, 0))
+ # Test that the props and ghost props are not constructed.
+ self.assertEmpty(task._props)
+ self.assertEmpty(task._ghost_props)
+
+ def test_prop_termination(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ disable_props=False,
+ prop_factory=props.Prop,
+ )
+ env = composer.Environment(task=task)
+ observation = env.reset().observation
+
+ # Test that prop position contributes to prop termination error.
+ task._set_walker(env.physics)
+ wrong_position = observation[REFERENCE_PROP_KEYS[0]] + np.ones(3)
+ task._props[0].set_pose(env.physics, wrong_position)
+ task.after_step(env.physics, 0)
+ task._compute_termination_error()
+ self.assertGreater(task._prop_termination_error, 0.)
+ task.get_reward(env.physics)
+ self.assertEqual(task._should_truncate, True)
+
+ def test_ghost_walker(self):
+ task = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ ghost_offset=None,
+ )
+ env = composer.Environment(task=task)
+ task_with_ghost = tracking.MultiClipMocapTracking(
+ walker=self.walker,
+ arena=self.arena,
+ ref_path=self.test_data,
+ dataset=types.ClipCollection(ids=('cmuv2019_001', 'cmuv2019_002')),
+ ref_steps=(0,),
+ min_steps=1,
+ ghost_offset=GHOST_OFFSET,
+ )
+ env_with_ghost = composer.Environment(task=task_with_ghost)
+ # Test that the ghost does not introduce additional actions.
+ self.assertEqual(env_with_ghost.action_spec(), env.action_spec())
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/tasks/reference_pose/types.py b/dm_control/locomotion/tasks/reference_pose/types.py
new file mode 100644
index 00000000..278b27ca
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/types.py
@@ -0,0 +1,55 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Types for reference pose tasks.
+"""
+
+from typing import Optional, Sequence, Text, Union
+
+import numpy as np
+
+
+class ClipCollection:
+ """Dataclass representing a collection of mocap reference clips."""
+
+ def __init__(self,
+ ids: Sequence[Text],
+ start_steps: Optional[Sequence[int]] = None,
+ end_steps: Optional[Sequence[int]] = None,
+ weights: Optional[Sequence[Union[int, float]]] = None):
+ """Instantiate a ClipCollection."""
+ self.ids = ids
+ self.start_steps = start_steps
+ self.end_steps = end_steps
+ self.weights = weights
+ num_clips = len(self.ids)
+ try:
+ if self.start_steps is None:
+ # by default start at the beginning
+ self.start_steps = (0,) * num_clips
+ else:
+ assert len(self.start_steps) == num_clips
+
+ # without access to the actual clip we cannot specify an end_steps default
+ if self.end_steps is not None:
+ assert len(self.end_steps) == num_clips
+
+ if self.weights is None:
+ self.weights = (1.0,) * num_clips
+ else:
+ assert len(self.weights) == num_clips
+ assert np.all(np.array(self.weights) >= 0.)
+ except AssertionError as e:
+ raise ValueError("ClipCollection validation failed. {}".format(e))
+
diff --git a/dm_control/locomotion/tasks/reference_pose/utils.py b/dm_control/locomotion/tasks/reference_pose/utils.py
new file mode 100644
index 00000000..1db02ccd
--- /dev/null
+++ b/dm_control/locomotion/tasks/reference_pose/utils.py
@@ -0,0 +1,170 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utils for reference pose tasks."""
+
+from dm_control import mjcf
+from dm_control.utils import transformations as tr
+import numpy as np
+
+
+def add_walker(walker_fn, arena, name='walker', ghost=False, visible=True,
+               position=(0, 0, 0)):
+  """Create a walker."""
+  walker = walker_fn(name=name)
+
+  if ghost:
+    # if the walker has a built-in tracking light remove it.
+    light = walker.mjcf_model.find('light', 'tracking_light')
+    if light:
+      light.remove()
+
+    # Remove the contacts.
+    for geom in walker.mjcf_model.find_all('geom'):
+      # alpha=0.999 ensures grey ghost reference.
+      # for alpha=1.0 there is no visible difference between real walker and
+      # ghost reference.
+      alpha = 0.999
+      if geom.rgba is not None and geom.rgba[3] < alpha:
+        alpha = geom.rgba[3]
+
+      geom.set_attributes(
+          contype=0,
+          conaffinity=0,
+          rgba=(0.5, 0.5, 0.5, alpha if visible else 0.0))
+
+    # We don't want ghost actuators to be controllable, so remove them.
+    model = walker.mjcf_model
+
+    elems = model.find_all('actuator')
+    sensors = [x for x in model.find_all('sensor') if 'actuator' in x.tag]
+    elems += sensors
+
+    for elem in elems:
+      elem.remove()
+
+    skin = walker.mjcf_model.find('skin', 'skin')
+    if skin:
+      if visible:
+        skin.set_attributes(rgba=(0.5, 0.5, 0.5, 0.999))
+      else:
+        skin.set_attributes(rgba=(0.5, 0.5, 0.5, 0.))
+
+  if position == (0, 0, 0):
+    walker.create_root_joints(arena.attach(walker))
+  else:
+    spawn_site = arena.mjcf_model.worldbody.add('site', pos=position)
+    walker.create_root_joints(arena.attach(walker, spawn_site))
+    spawn_site.remove()
+
+  return walker
+
+
+def get_qpos_qvel_from_features(features):
+  """Get qpos and qvel from logged features to set walker."""
+  full_qpos = np.hstack([
+      features['position'],
+      features['quaternion'],
+      features['joints'],
+  ])
+  full_qvel = np.hstack([
+      features['velocity'],
+      features['angular_velocity'],
+      features['joints_velocity'],
+  ])
+  return full_qpos, full_qvel
+
+
+def set_walker_from_features(physics, walker, features, offset=0):
+  """Set the freejoint and walker's joints angles and velocities."""
+  qpos, qvel = get_qpos_qvel_from_features(features)
+  set_walker(physics, walker, qpos, qvel, offset=offset)
+
+
+def set_walker(physics, walker, qpos, qvel, offset=0, null_xyz_and_yaw=False,
+               position_shift=None, rotation_shift=None):
+  """Set the freejoint and walker's joints angles and velocities."""
+  qpos = np.array(qpos)
+  if null_xyz_and_yaw:
+    qpos[:2] = 0.
+    euler = tr.quat_to_euler(qpos[3:7], ordering='ZYX')
+    euler[0] = 0.
+    quat = tr.euler_to_quat(euler, ordering='ZYX')
+    qpos[3:7] = quat
+  qpos[:3] += offset
+
+  freejoint = mjcf.get_attachment_frame(walker.mjcf_model).freejoint
+
+  physics.bind(freejoint).qpos = qpos[:7]
+  physics.bind(freejoint).qvel = qvel[:6]
+
+  physics.bind(walker.mocap_joints).qpos = qpos[7:]
+  physics.bind(walker.mocap_joints).qvel = qvel[6:]
+  if position_shift is not None or rotation_shift is not None:
+    walker.shift_pose(physics, position=position_shift,
+                      quaternion=rotation_shift, rotate_velocity=True)
+
+
+def set_props_from_features(physics, props, features, z_offset=0):
+  positions = features['prop_positions']
+  quaternions = features['prop_quaternions']
+  if np.isscalar(z_offset):
+    z_offset = np.array([0., 0., z_offset])
+  for prop, pos, quat in zip(props, positions, quaternions):
+    prop.set_pose(physics, pos + z_offset, quat)
+
+
+def get_features(physics, walker, props=None):
+  """Get walker features for reward functions."""
+  walker_bodies = walker.mocap_tracking_bodies
+
+  walker_features = {}
+  root_pos, root_quat = walker.get_pose(physics)
+  walker_features['position'] = np.array(root_pos)
+  walker_features['quaternion'] = np.array(root_quat)
+  joints = np.array(physics.bind(walker.mocap_joints).qpos)
+  walker_features['joints'] = joints
+  freejoint_frame = mjcf.get_attachment_frame(walker.mjcf_model)
+
+  com = np.array(physics.bind(freejoint_frame).subtree_com)
+  walker_features['center_of_mass'] = com
+  end_effectors = np.array(
+      walker.observables.end_effectors_pos(physics)[:]).reshape(-1, 3)
+  walker_features['end_effectors'] = end_effectors
+  if hasattr(walker.observables, 'appendages_pos'):
+    appendages = np.array(
+        walker.observables.appendages_pos(physics)[:]).reshape(-1, 3)
+  else:
+    appendages = np.array(end_effectors)
+  walker_features['appendages'] = appendages
+  xpos = np.array(physics.bind(walker_bodies).xpos)
+  walker_features['body_positions'] = xpos
+  xquat = np.array(physics.bind(walker_bodies).xquat)
+  walker_features['body_quaternions'] = xquat
+  root_vel, root_angvel = walker.get_velocity(physics)
+  walker_features['velocity'] = np.array(root_vel)
+  walker_features['angular_velocity'] = np.array(root_angvel)
+  joints_vel = np.array(physics.bind(walker.mocap_joints).qvel)
+  walker_features['joints_velocity'] = joints_vel
+
+  if props:
+    positions = []
+    quaternions = []
+    for prop in props:
+      pos, quat = prop.get_pose(physics)
+      positions.append(pos)
+      quaternions.append(quat)
+    walker_features['prop_positions'] = np.array(positions)
+    walker_features['prop_quaternions'] = np.array(quaternions)
+  return walker_features
diff --git a/dm_control/locomotion/walkers/__init__.py b/dm_control/locomotion/walkers/__init__.py
new file mode 100644
index 00000000..48c8fe11
--- /dev/null
+++ b/dm_control/locomotion/walkers/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Walkers for Locomotion tasks."""
+
+from dm_control.locomotion.walkers.ant import Ant
+from dm_control.locomotion.walkers.cmu_humanoid import CMUHumanoidPositionControlled
+from dm_control.locomotion.walkers.cmu_humanoid import CMUHumanoidPositionControlledV2020
+# Import removed.
+from dm_control.locomotion.walkers.jumping_ball import JumpingBallWithHead
+from dm_control.locomotion.walkers.jumping_ball import RollingBallWithHead
+from dm_control.locomotion.walkers.rodent import Rat
diff --git a/dm_control/locomotion/walkers/ant.py b/dm_control/locomotion/walkers/ant.py
new file mode 100644
index 00000000..22d6964f
--- /dev/null
+++ b/dm_control/locomotion/walkers/ant.py
@@ -0,0 +1,207 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A quadruped "ant" walker."""
+
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+import numpy as np
+
+_XML_DIRNAME = os.path.join(os.path.dirname(__file__), '../../third_party/ant')
+_XML_FILENAME = 'ant.xml'
+
+
+class Ant(legacy_base.Walker):
+  """A quadruped "Ant" walker."""
+
+  def _build(self, name='walker', marker_rgba=None, initializer=None):
+    """Build an Ant walker.
+
+    Args:
+      name: name of the walker.
+      marker_rgba: (Optional) color the ant's front legs with marker_rgba.
+      initializer: (Optional) A `WalkerInitializer` object.
+    """
+    super()._build(initializer=initializer)
+    self._appendages_sensors = []
+    self._bodies_pos_sensors = []
+    self._bodies_quats_sensors = []
+    self._mjcf_root = mjcf.from_path(os.path.join(_XML_DIRNAME, _XML_FILENAME))
+    if name:
+      self._mjcf_root.model = name
+
+    # Set corresponding marker color if specified.
+    if marker_rgba is not None:
+      for geom in self.marker_geoms:
+        geom.set_attributes(rgba=marker_rgba)
+
+    # Initialize previous action.
+    self._prev_action = np.zeros(shape=self.action_spec.shape,
+                                 dtype=self.action_spec.dtype)
+
+  def initialize_episode(self, physics, random_state):
+    self._prev_action = np.zeros_like(self._prev_action)
+
+  def apply_action(self, physics, action, random_state):
+    super().apply_action(physics, action, random_state)
+
+    # Updates previous action.
+    self._prev_action[:] = action
+
+  def _build_observables(self):
+    return AntObservables(self)
+
+  @property
+  def mjcf_model(self):
+    return self._mjcf_root
+
+  @property
+  def upright_pose(self):
+    return base.WalkerPose()
+
+  @property
+  def marker_geoms(self):
+    return [self._mjcf_root.find('geom', 'front_left_leg_geom'),
+            self._mjcf_root.find('geom', 'front_right_leg_geom')]
+
+  @composer.cached_property
+  def actuators(self):
+    return self._mjcf_root.find_all('actuator')
+
+  @composer.cached_property
+  def root_body(self):
+    return self._mjcf_root.find('body', 'torso')
+
+  @composer.cached_property
+  def bodies(self):
+    return tuple(self.mjcf_model.find_all('body'))
+
+  @composer.cached_property
+  def mocap_tracking_bodies(self):
+    """Collection of bodies for mocap tracking."""
+    return tuple(self.mjcf_model.find_all('body'))
+
+  @property
+  def mocap_joints(self):
+    return self.mjcf_model.find_all('joint')
+
+  @property
+  def _foot_bodies(self):
+    return (self._mjcf_root.find('body', 'front_left_foot'),
+            self._mjcf_root.find('body', 'front_right_foot'),
+            self._mjcf_root.find('body', 'back_right_foot'),
+            self._mjcf_root.find('body', 'back_left_foot'))
+
+  @composer.cached_property
+  def end_effectors(self):
+    return self._foot_bodies
+
+  @composer.cached_property
+  def observable_joints(self):
+    return [actuator.joint for actuator in self.actuators]  # pylint: disable=not-an-iterable
+
+  @composer.cached_property
+  def egocentric_camera(self):
+    return self._mjcf_root.find('camera', 'egocentric')
+
+  def aliveness(self, physics):
+    return (physics.bind(self.root_body).xmat[-1] - 1.) / 2.
+
+  @composer.cached_property
+  def ground_contact_geoms(self):
+    foot_geoms = []
+    for foot in self._foot_bodies:
+      foot_geoms.extend(foot.find_all('geom'))
+    return tuple(foot_geoms)
+
+  @property
+  def prev_action(self):
+    return self._prev_action
+
+  @property
+  def appendages_sensors(self):
+    return self._appendages_sensors
+
+  @property
+  def bodies_pos_sensors(self):
+    return self._bodies_pos_sensors
+
+  @property
+  def bodies_quats_sensors(self):
+    return self._bodies_quats_sensors
+
+
+class AntObservables(legacy_base.WalkerObservables):
+  """Observables for the Ant."""
+
+  @composer.observable
+  def appendages_pos(self):
+    """Positions of the appendages (the Ant's feet), in the egocentric frame."""
+    appendages = self._entity.end_effectors
+    self._entity.appendages_sensors[:] = []
+    for body in appendages:
+      self._entity.appendages_sensors.append(
+          self._entity.mjcf_model.sensor.add(
+              'framepos', name=body.name + '_appendage',
+              objtype='xbody', objname=body,
+              reftype='xbody', refname=self._entity.root_body))
+    def appendages_ego_pos(physics):
+      return np.reshape(
+          physics.bind(self._entity.appendages_sensors).sensordata, -1)
+    return observable.Generic(appendages_ego_pos)
+
+  @composer.observable
+  def bodies_quats(self):
+    """Orientations of the bodies as quaternions, in the egocentric frame."""
+    bodies = self._entity.bodies
+    self._entity.bodies_quats_sensors[:] = []
+    for body in bodies:
+      self._entity.bodies_quats_sensors.append(
+          self._entity.mjcf_model.sensor.add(
+              'framequat', name=body.name + '_ego_body_quat',
+              objtype='xbody', objname=body,
+              reftype='xbody', refname=self._entity.root_body))
+    def bodies_ego_orientation(physics):
+      return np.reshape(
+          physics.bind(self._entity.bodies_quats_sensors).sensordata, -1)
+    return observable.Generic(bodies_ego_orientation)
+
+  @composer.observable
+  def bodies_pos(self):
+    """Position of bodies relative to root, in the egocentric frame."""
+    bodies = self._entity.bodies
+    self._entity.bodies_pos_sensors[:] = []
+    for body in bodies:
+      self._entity.bodies_pos_sensors.append(
+          self._entity.mjcf_model.sensor.add(
+              'framepos', name=body.name + '_ego_body_pos',
+              objtype='xbody', objname=body,
+              reftype='xbody', refname=self._entity.root_body))
+    def bodies_ego_pos(physics):
+      return np.reshape(
+          physics.bind(self._entity.bodies_pos_sensors).sensordata, -1)
+    return observable.Generic(bodies_ego_pos)
+
+  @property
+  def proprioception(self):
+    return ([self.joints_pos, self.joints_vel,
+             self.body_height, self.end_effectors_pos,
+             self.appendages_pos, self.world_zaxis,
+             self.bodies_quats, self.bodies_pos] +
+            self._collect_from_attachments('proprioception'))
diff --git a/dm_control/locomotion/walkers/ant_test.py b/dm_control/locomotion/walkers/ant_test.py
new file mode 100644
index 00000000..e3ed0083
--- /dev/null
+++ b/dm_control/locomotion/walkers/ant_test.py
@@ -0,0 +1,100 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the Ant."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.walkers import ant
+import numpy as np
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.005
+
+
+def _get_ant_corridor_physics():
+  walker = ant.Ant()
+  arena = corr_arenas.EmptyCorridor()
+  task = corr_tasks.RunThroughCorridor(
+      walker=walker,
+      arena=arena,
+      walker_spawn_position=(5, 0, 0),
+      walker_spawn_rotation=0,
+      physics_timestep=_PHYSICS_TIMESTEP,
+      control_timestep=_CONTROL_TIMESTEP)
+
+  env = composer.Environment(
+      time_limit=30,
+      task=task,
+      strip_singleton_obs_buffer_dim=True)
+
+  return walker, env
+
+
+class AntTest(parameterized.TestCase):
+
+  def test_can_compile_and_step_simulation(self):
+    _, env = _get_ant_corridor_physics()
+    physics = env.physics
+    for _ in range(100):
+      physics.step()
+
+  @parameterized.parameters([
+      'egocentric_camera',
+      'root_body',
+  ])
+  def test_get_element_property(self, name):
+    attribute_value = getattr(ant.Ant(), name)
+    self.assertIsInstance(attribute_value, mjcf.Element)
+
+  @parameterized.parameters([
+      'actuators',
+      'end_effectors',
+      'observable_joints',
+  ])
+  def test_get_element_tuple_property(self, name):
+    attribute_value = getattr(ant.Ant(), name)
+    self.assertNotEmpty(attribute_value)
+    for item in attribute_value:
+      self.assertIsInstance(item, mjcf.Element)
+
+  def test_set_name(self):
+    name = 'fred'
+    walker = ant.Ant(name=name)
+    self.assertEqual(walker.mjcf_model.model, name)
+
+  @parameterized.parameters(
+      'appendages_pos',
+      'sensors_touch',
+  )
+  def test_evaluate_observable(self, name):
+    walker, env = _get_ant_corridor_physics()
+    physics = env.physics
+    observable = getattr(walker.observables, name)
+    observation = observable(physics)
+    self.assertIsInstance(observation, (float, np.ndarray))
+
+  def test_proprioception(self):
+    walker = ant.Ant()
+    for item in walker.observables.proprioception:
+      self.assertIsInstance(item, observable_base.Observable)
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_body.msh
new file mode 100644
index 00000000..6fd0e732
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_lower.msh
new file mode 100644
index 00000000..fe06a608
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_1_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_body.msh
new file mode 100644
index 00000000..1d9e39f0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_lower.msh
new file mode 100644
index 00000000..50c6e35f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_2_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_body.msh
new file mode 100644
index 00000000..bb74f96f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_lower.msh
new file mode 100644
index 00000000..9c9b4fc6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_3_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_body.msh
new file mode 100644
index 00000000..18f5df14
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_lower.msh
new file mode 100644
index 00000000..d5536126
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_4_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_body.msh
new file mode 100644
index 00000000..df98adb7
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_lower.msh
new file mode 100644
index 00000000..c23560ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_5_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_body.msh
new file mode 100644
index 00000000..1112a798
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_lower.msh
new file mode 100644
index 00000000..eb6a4a1f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_6_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_body.msh
new file mode 100644
index 00000000..d6aba9c8
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_lower.msh
new file mode 100644
index 00000000..a1a40d27
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_7_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_8_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_8_body.msh
new file mode 100644
index 00000000..e5c53f11
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/abdomen_8_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_black.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_black.msh
new file mode 100644
index 00000000..353f67a6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_body.msh
new file mode 100644
index 00000000..c274dc87
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_black.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_black.msh
new file mode 100644
index 00000000..9a742a36
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_body.msh
new file mode 100644
index 00000000..a0493295
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/antenna_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_left_body.msh
new file mode 100644
index 00000000..9df055b0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_right_body.msh
new file mode 100644
index 00000000..990dd6c6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_left_body.msh
new file mode 100644
index 00000000..ef693f0e
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_right_body.msh
new file mode 100644
index 00000000..98be03ac
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_left_body.msh
new file mode 100644
index 00000000..256dc0cc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_right_body.msh
new file mode 100644
index 00000000..3e7814a1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/coxa_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila.xml b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila.xml
new file mode 100644
index 00000000..aa7332fb
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila.xml
@@ -0,0 +1,510 @@
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_defaults.xml b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_defaults.xml
new file mode 100644
index 00000000..f5ae31b4
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_defaults.xml
@@ -0,0 +1,224 @@
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_fused.xml b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_fused.xml
new file mode 100644
index 00000000..11abf9cc
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/drosophila_fused.xml
@@ -0,0 +1,436 @@
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_left_body.msh
new file mode 100644
index 00000000..59d3fc7b
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_right_body.msh
new file mode 100644
index 00000000..0d385990
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_left_body.msh
new file mode 100644
index 00000000..cd49ae92
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_right_body.msh
new file mode 100644
index 00000000..2804b089
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_left_body.msh
new file mode 100644
index 00000000..a3ce197f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_right_body.msh
new file mode 100644
index 00000000..93819cee
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/femur_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_left_body.msh
new file mode 100644
index 00000000..bc4d52d1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_right_body.msh
new file mode 100644
index 00000000..545cbab0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haltere_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_black.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_black.msh
new file mode 100644
index 00000000..ce3c56f3
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_body.msh
new file mode 100644
index 00000000..f9d83b99
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/haustellum_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_black.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_black.msh
new file mode 100644
index 00000000..ffb2f26b
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_body.msh
new file mode 100644
index 00000000..4236c6db
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_ocelli.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_ocelli.msh
new file mode 100644
index 00000000..4a4c281a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_ocelli.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_red.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_red.msh
new file mode 100644
index 00000000..53d30f86
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/head_red.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_left_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_left_lower.msh
new file mode 100644
index 00000000..f00f2256
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_left_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_right_lower.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_right_lower.msh
new file mode 100644
index 00000000..b183450c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/labrum_right_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_body.msh
new file mode 100644
index 00000000..f0bd54e3
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_bristle-brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_bristle-brown.msh
new file mode 100644
index 00000000..f5dca535
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/rostrum_bristle-brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_left_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_left_brown.msh
new file mode 100644
index 00000000..104fa46f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_right_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_right_brown.msh
new file mode 100644
index 00000000..51a8ce90
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T1_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_left_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_left_brown.msh
new file mode 100644
index 00000000..9fd8426a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_right_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_right_brown.msh
new file mode 100644
index 00000000..bd8127c6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T2_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_left_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_left_brown.msh
new file mode 100644
index 00000000..4c224a0d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_right_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_right_brown.msh
new file mode 100644
index 00000000..f1a12541
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsal_claw_T3_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_left_body.msh
new file mode 100644
index 00000000..0d2926cc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_right_body.msh
new file mode 100644
index 00000000..b2238e9a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_left_body.msh
new file mode 100644
index 00000000..f8c0375f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_right_body.msh
new file mode 100644
index 00000000..81eba2ea
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_left_body.msh
new file mode 100644
index 00000000..ae1e169d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_right_body.msh
new file mode 100644
index 00000000..475817de
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_left_body.msh
new file mode 100644
index 00000000..116844cb
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_right_body.msh
new file mode 100644
index 00000000..97acf78d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T1_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_left_body.msh
new file mode 100644
index 00000000..46ed4fb4
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_right_body.msh
new file mode 100644
index 00000000..b49ee4c2
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_left_body.msh
new file mode 100644
index 00000000..7b81306c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_right_body.msh
new file mode 100644
index 00000000..d1a7d6ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_left_body.msh
new file mode 100644
index 00000000..fa26f25a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_right_body.msh
new file mode 100644
index 00000000..e66cb2de
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_left_body.msh
new file mode 100644
index 00000000..72fb71ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_right_body.msh
new file mode 100644
index 00000000..563fb5e0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T2_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_left_body.msh
new file mode 100644
index 00000000..5cf0ad6f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_right_body.msh
new file mode 100644
index 00000000..45b35ea0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_left_body.msh
new file mode 100644
index 00000000..3a6e0592
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_right_body.msh
new file mode 100644
index 00000000..7ee5024e
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_left_body.msh
new file mode 100644
index 00000000..9fca2447
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_right_body.msh
new file mode 100644
index 00000000..416e429d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_left_body.msh
new file mode 100644
index 00000000..ae032249
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_right_body.msh
new file mode 100644
index 00000000..fb4a36d5
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tarsus_T3_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_black.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_black.msh
new file mode 100644
index 00000000..652ba9fc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_body.msh
new file mode 100644
index 00000000..31965586
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/thorax_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_left_body.msh
new file mode 100644
index 00000000..0c6104c1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_right_body.msh
new file mode 100644
index 00000000..533335e2
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_left_body.msh
new file mode 100644
index 00000000..ba38624c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_right_body.msh
new file mode 100644
index 00000000..61ab8cf5
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_left_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_left_body.msh
new file mode 100644
index 00000000..c30f9d4f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_right_body.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_right_body.msh
new file mode 100644
index 00000000..8a5575a1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/tibia_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_brown.msh
new file mode 100644
index 00000000..7c30f774
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_membrane.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_membrane.msh
new file mode 100644
index 00000000..a07ca11a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_left_membrane.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_brown.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_brown.msh
new file mode 100644
index 00000000..ee5f18ec
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_membrane.msh b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_membrane.msh
new file mode 100644
index 00000000..3550adf8
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/build_fruitfly/assets/wing_right_membrane.msh differ
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/floor.xml b/dm_control/locomotion/walkers/assets/build_fruitfly/floor.xml
new file mode 100644
index 00000000..98b1861b
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/floor.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/fruitfly.xml b/dm_control/locomotion/walkers/assets/build_fruitfly/fruitfly.xml
new file mode 100644
index 00000000..a68e3e39
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/fruitfly.xml
@@ -0,0 +1,918 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/fuse_fruitfly.py b/dm_control/locomotion/walkers/assets/build_fruitfly/fuse_fruitfly.py
new file mode 100644
index 00000000..7fc11272
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/fuse_fruitfly.py
@@ -0,0 +1,64 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Fuse fruitfly model."""
+
+import os
+from typing import Sequence
+
+from absl import app
+
+from dm_control import mujoco
+from lxml import etree
+
+ASSET_RELPATH = 'assets/'
+ASSET_DIR = os.path.dirname(__file__) + '/' + ASSET_RELPATH
+BASE_MODEL = 'drosophila_defaults.xml'
+FLY_MODEL = 'drosophila.xml'  # Raw model as exported from Blender.
+FUSED_MODEL = ASSET_DIR + 'drosophila_fused.xml'
+
+
+def main(argv: Sequence[str]):
+  if len(argv) > 1:
+    raise app.UsageError('Too many command-line arguments.')
+
+  print('Load base model.')
+  with open(os.path.join(ASSET_DIR, FLY_MODEL), 'r') as f:
+    tree = etree.XML(f.read(), etree.XMLParser(remove_blank_text=True))
+
+  print('Remove lights.')
+  lights = tree.xpath('.//light')
+  for light in lights:
+    light.getparent().remove(light)
+
+  print('Set fusestatic option.')
+  compiler = tree.find('compiler')
+  compiler.attrib['fusestatic'] = 'true'
+  del compiler.attrib['boundmass']
+  del compiler.attrib['boundinertia']
+
+  print('Add freejoint.')
+  root = tree.find('worldbody').find('body')
+  root.getchildren()[0].addprevious(etree.Element('freejoint'))
+
+  print('Load physics, fuse.')
+  physics = mujoco.Physics.from_xml_string(etree.tostring(tree,
+                                                          pretty_print=True))
+
+  print('Save fused model.')
+  mujoco.mj_saveLastXML(os.path.join(FUSED_MODEL), physics.model.ptr)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/dm_control/locomotion/walkers/assets/build_fruitfly/make_fruitfly.py b/dm_control/locomotion/walkers/assets/build_fruitfly/make_fruitfly.py
new file mode 100644
index 00000000..b31bcaf2
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/build_fruitfly/make_fruitfly.py
@@ -0,0 +1,1214 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Make fruitfly model."""
+
+import itertools
+import os
+from typing import Sequence
+
+from absl import app
+
+from dm_control import mjcf
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+from lxml import etree
+import numpy as np
+
+ASSET_RELPATH = 'assets/'
+ASSET_DIR = os.path.dirname(__file__) + '/' + ASSET_RELPATH
+BASE_MODEL = 'drosophila_defaults.xml'
+FLY_MODEL = 'drosophila_fused.xml'  # Pre-generated by fuse_fruitfly.py
+
+FINAL_MODEL = 'fruitfly.xml'  # Output file of this script.
+
+_YAW_AXIS_PITCH = -47.5 * np.pi / 180
+
+# Empirical mass values, milligrams.
+_MASS = {'head': 0.15,
+         'thorax': 0.34,
+         'abdomen': 0.38,
+         'leg': 0.0162,
+         'wing': 0.008}
+
+
+# Utility functions:
+def mul_quat(quat_a, quat_b):
+  """Returns quat_a * quat_b."""
+  quat_c = np.zeros(4)
+  mjlib.mju_mulQuat(quat_c, quat_a, quat_b)
+  return quat_c
+
+
+def quat_to_mat(quat):
+  """Converts quaternion to rotation matrix."""
+  mat = np.zeros(9)
+  mjlib.mju_quat2Mat(mat, quat)
+  return mat.reshape(3, 3)
+
+
+def mat_to_quat(mat):
+  """Converts rotation matrix to quaternion."""
+  quat = np.zeros(4)
+  mjlib.mju_mat2Quat(quat, mat.flatten())
+  return quat
+
+
+def neg_quat(quat_a):
+  """Returns quat_a with its scalar part negated (the inverse rotation)."""
+  neg_quat_a = quat_a.copy()
+  neg_quat_a[0] *= -1
+  return neg_quat_a
+
+
+def rot_vec_quat(vec, quat):
+  rot = np.zeros(3)
+  mjlib.mju_rotVecQuat(rot, vec, quat)
+  return rot
+
+
+def quat_z2vec(vec):
+  """Construct quaternion performing rotation from z-axis to given vector."""
+  quat = np.zeros(4)
+  mjlib.mju_quatZ2Vec(quat, vec)
+  return quat
+
+
+def change_body_frame(body, frame_pos, frame_quat):
+  """Change the frame of a body while maintaining child locations."""
+  frame_pos = np.zeros(3) if frame_pos is None else frame_pos
+  frame_quat = np.array((1., 0, 0, 0)) if frame_quat is None else frame_quat
+  # Get frame transformation.
+  body_pos = np.zeros(3) if body.pos is None else body.pos
+  dpos = body_pos - frame_pos
+  body_quat = np.array((1., 0, 0, 0)) if body.quat is None else body.quat
+  dquat = mul_quat(neg_quat(frame_quat), body_quat)
+  # Translate and rotate the body to the new frame.
+  body.pos = frame_pos
+  body.quat = frame_quat
+  # Move all its children to their previous location.
+  for child in body.all_children():
+    if not hasattr(child, 'pos'):
+      continue
+    # Rotate:
+    if hasattr(child, 'quat'):
+      child_quat = np.array((1., 0, 0, 0)) if child.quat is None else child.quat
+      child.quat = mul_quat(dquat, child_quat)
+    # Translate, accounting for rotations.
+    child_pos = np.zeros(3) if child.pos is None else child.pos
+    pos_in_parent = rot_vec_quat(child_pos, body_quat) + dpos
+    child.pos = rot_vec_quat(pos_in_parent, neg_quat(frame_quat))
+
+
+def main(argv: Sequence[str]):
+  if len(argv) > 1:
+    raise app.UsageError('Too many command-line arguments.')
+
+  print('Load base models.')
+  with open(os.path.join(ASSET_DIR, BASE_MODEL), 'r') as f:
+    modeltree = etree.XML(f.read(), etree.XMLParser(remove_blank_text=True))
+  with open(os.path.join(ASSET_DIR, FLY_MODEL), 'r') as f:
+    flytree = etree.XML(f.read(), etree.XMLParser(remove_blank_text=True))
+
+  print('Combine fly model with defaults.')
+  worldbody = modeltree.find('worldbody')
+  worldbody.addprevious(flytree.find('asset'))
+
+  print('Append root.')
+  all_bodies = flytree.xpath('.//body')
+  thorax = None
+  for body in all_bodies:
+    if body.get('name') is not None and 'Armature' == body.get('name'):
+      thorax = body
+  if thorax is not None:
+    worldbody.append(thorax)
+
+  print('Load as mjcf model.')
+  model = mjcf.from_xml_string(etree.tostring(modeltree, pretty_print=True),
+                               model_dir=ASSET_DIR)
+
+  print('Fix shininess.')
+  for material in model.find_all('material'):
+    material.shininess = round(material.shininess * 1e6) * 1e-6
+
+  print('Add global cameras.')
+  model.worldbody.add('camera', name='hero', pos=(0.271, 0.270, -0.044),
+                      xyaxes=(-0.641, 0.767, 0, -0.045, -0.037, 0.998))
+
+  print('Remove _Armature suffix.')
+  things = [model.find_all(thing) for thing in ['body', 'joint', 'geom']]
+  things = sum(things, [])
+  for thing in things:
+    if thing.name is not None:
+      thing.name = thing.name.replace('_Armature', '')
+
+  print('Remove _body suffix.')
+  for geom in model.find_all('geom'):
+    if geom.name.endswith('_body'):
+      geom.name = geom.name.replace('_body', '')
+  for mesh in model.find_all('mesh'):
+    if mesh.name.endswith('_body'):
+      mesh.name = mesh.name.replace('_body', '')
+
+  print('Rescale mm -> cm, remove (0, 0, 0).')
+  for thing in things:
+    if thing.pos is not None:
+      thing.pos = thing.pos / 10.0
+      if np.all(thing.pos == 0):
+        thing.pos = None
+
+  print('Use radians.')
+  model.compiler.angle = 'radian'
+
+  print('Set autolimits="true".')
+  model.compiler.autolimits = 'true'
+
+  print('Remove inertial clauses.')
+  for body in model.find_all('body'):
+    children = body.all_children()
+    for child in children:
+      if child.tag == 'inertial':
+        child.remove()
+
+  print('Make wings translucent.')
+  model.find('material', 'membrane').rgba[3] = 0.4
+
+  print('Get, rename thorax.')
+  thorax = model.find('body', 'Armature')
+  thorax.name = 'thorax'
+
+  print('Rotate thorax to face positive x-axis.')
+  thorax.quat = (0, 0, 0, 1.)
+
+  print('Set thorax mass.')
+  thorax.find('geom', 'thorax').mass = _MASS['thorax'] * 1e-3
+
+  print('Rename freejoint.')
+  freejoint = thorax.get_children('joint')[0]
+  freejoint.remove()
+  thorax.insert('freejoint', 0, name='free')
+
+  print('Sort thorax children.')
+  sort_order = ['free', 'thorax', 'head', 'wing', 'abdomen',
+                'haltere_left', 'haltere_right', 'coxa']
+  child_names = [e.name for e in thorax._children]  # pylint: disable=protected-access
+  resort = []
+  for type_name in sort_order:
+    indices = [j for j, s in enumerate(child_names) if type_name in s]
+    resort.extend(indices)
+  assert len(child_names) == len(resort)
+  thorax._children = [thorax._children[i] for i in resort]  # pylint: disable=protected-access
+
+  print('Sort labrums.')
+  haustellum = thorax.find('body', 'haustellum')
+  sort_order = ['haustellum', 'labrum_left', 'labrum_right']
+  child_names = [e.name for e in haustellum._children]  # pylint: disable=protected-access
+  resort = []
+  for type_name in sort_order:
+    indices = [j for j, s in enumerate(child_names) if type_name in s]
+    resort.extend(indices)
+  assert len(child_names) == len(resort)
+  haustellum._children = [haustellum._children[i] for i in resort]  # pylint: disable=protected-access
+
+  print('Retract haustellum.')
+  haustellum.find('joint', 'rx_haustellum').range[1] = 0.7
+  haustellum.find('joint', 'rx_haustellum').springref = 0.8
+
+  print('Set body childclass.')
+  thorax.childclass = 'body'
+
+  print('Remove default-specified values.')
+  for joint in thorax.find_all('joint'):
+    if joint.tag != 'freejoint' and joint.limited is not None:
+      joint.limited = None
+  for geom in thorax.find_all('geom'):
+    if geom.type is not None and geom.type == 'mesh':
+      geom.type = None
+    if geom.material is not None and geom.material.name == 'body':
+      geom.material = None
+
+  print('Rename leg elements.')
+  legs = [body for body in thorax.find_all('body') if 'coxa' in body.name]
+  links = ['coxa', 'femur', 'tibia', 'tarsus', 'tarsus2',
+           'tarsus3', 'tarsus4']
+  sternums = ['T1', 'T2', 'T3']
+  sides = ['right', 'left']
+  for sternum in sternums:
+    for side in sides:
+      for leg in legs:
+        if sternum in leg.name and side in leg.name:
+          body = leg
+          for link in links:
+            name = link + '_' + sternum + '_' + side
+            old_name = body.name
+            body.name = name
+            for geom in body.get_children('geom'):
+              if old_name in geom.name:
+                geom.name = geom.name.replace(old_name, name)
+            for joint in body.get_children('joint'):
+              if old_name in joint.name:
+                joint.name = joint.name.replace(old_name, name)
+            children = body.get_children('body')
+            if children:
+              body = children[0]
+
+  print('Remove tarsus abductors.')
+  tarsi = [body for body in thorax.find_all('body') if 'tarsus' in body.name]
+  for tarsus in tarsi:
+    for joint in tarsus.get_children('joint'):
+      if 'rz' in joint.name and 'tarsus_' not in joint.name:
+        joint.remove()
+
+  print('Add claw bodies and joints.')
+  tarsi4 = [body for body in thorax.find_all('body') if 'tarsus4' in body.name]
+  for tarsus4 in tarsi4:
+    gclaw = [geom for geom in tarsus4.get_children('geom')
+             if 'claw' in geom.name][0]
+    ipos = tarsus4.pos
+    claw = tarsus4.add('body', name=tarsus4.name.replace('tarsus4', 'claw'),
+                       pos=ipos)
+    name = gclaw.name.replace('_brown', '')
+    material = gclaw.material
+    mesh = gclaw.mesh
+    quat = gclaw.quat
+    pos = gclaw.pos - ipos
+    gclaw.remove()
+    claw.add('geom', name=name, material=material, quat=quat, mesh=mesh,
+             pos=pos)
+    joint_name = name.replace('tarsal_claw', 'rx_tarsus5')
+    claw.insert('joint', 0, name=joint_name, axis=(1, 0, 0), pos=(0, 0, 0),
+                range=(-.1, .1))
+
+  print('Hair geoms should not contribute mass.')
+  black_material = model.find('material', 'black')
+  for geom in model.find_all('geom'):
+    if geom.material == black_material:
+      geom.mass = 0
+
+ print('Symmetrize thorax children.')
+ thorax_children = list(thorax.all_children())
+ for left in thorax_children:
+ if 'left' in left.name:
+ right = thorax.find('body', left.name.replace('left', 'right'))
+ pos = (right.pos + left.pos * np.array([1, -1, 1]))/2
+ right.pos = pos
+ left.pos = pos * np.array([1, -1, 1])
+
+ print('Make collision materials.')
+ base = model.asset.find('material', 'base')
+ base.name = 'blue'
+ base.rgba = (0.2, 0.3, 1, 1)
+ model.asset.add('material', name='pink', rgba=(0.6, 0.3, 1, 1))
+
+ print('== Infer collision geoms.')
+ print('Infer leg collision geoms.')
+ legs_meshes = []
+ for leg in legs:
+ legs_meshes.append(leg.find_all('geom'))
+ for leg_meshes in legs_meshes:
+ for geom in leg_meshes:
+ if 'coxa' in geom.name:
+ geom.type = 'ellipsoid'
+ else:
+ geom.type = 'capsule'
+
+ print('Infer wing collision geoms.')
+ wings = [body for body in thorax.find_all('body') if 'wing' in body.name]
+ wing_meshes = []
+ for wing in wings:
+ for geom in wing.find_all('geom'):
+ wing_meshes.append(geom)
+ geom.type = 'ellipsoid'
+
+ print('Infer mouth collision geoms.')
+ haustellum_mesh = thorax.find('geom', 'haustellum')
+ haustellum_mesh.type = 'capsule'
+ mouth_meshes = [g for g in thorax.find_all('geom') if 'labrum' in g.name]
+ for geom in mouth_meshes:
+ geom.type = 'ellipsoid'
+ mouth_meshes.append(haustellum_mesh)
+
+ print('Recompile.')
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ print('Re-center thorax CoM at origin.')
+ offset = physics.bind(thorax).xipos
+ offset[1] = 0.0
+ for geom in thorax.get_children('geom'):
+ geom.pos -= offset
+ if geom.quat is not None and geom.quat[0] == 1:
+ geom.quat = None
+ for body in thorax.get_children('body'):
+ body.pos -= offset
+
+ print('Set T3 coxa frame positions.')
+ for leg in legs:
+ if 'T3' in leg.name:
+ pos = np.array((0.021875, 0.01925, -0.04025))
+ if 'left' in leg.name:
+ pos[1] *= -1
+ change_body_frame(leg, pos, leg.quat)
+ # reset joint positions
+ for joint in leg.get_children('joint'):
+ joint.pos = None
+
+ print('Set leg frames, in reverse order.')
+ for leg in legs:
+ coxa = physics.bind(leg)
+ femur = leg.get_children('body')[0]
+ tibia = femur.get_children('body')[0]
+ tarsus = tibia.get_children('body')[0]
+ claw = leg.find('body', leg.name.replace('coxa', 'claw'))
+ bound_femur = physics.bind(femur)
+ bound_tibia = physics.bind(tibia)
+ bound_tarsus = physics.bind(tarsus)
+ bound_claw = physics.bind(claw)
+
+ if 'left' in leg.name:
+ quat = np.array((0., 1., 0, 0))
+ else:
+ quat = np.array((1., 0, 0, 0))
+
+ # upper tarsus
+ if 'T2' in leg.name or 'T3' in leg.name:
+ tarsus_to_claw = bound_claw.xpos - bound_tarsus.xpos
+ tarsus_to_tibia = bound_tibia.xpos - bound_tarsus.xpos
+ extend = -np.cross(tarsus_to_claw, tarsus_to_tibia)
+ extend /= np.linalg.norm(extend)
+ twist = tarsus_to_claw
+ twist /= np.linalg.norm(twist)
+ if 'right' in leg.name:
+ twist *= -1
+ abduct = np.cross(-extend, twist)
+ xmat = np.vstack((-extend, twist, abduct)).T
+ mat = bound_tibia.xmat.reshape(3, 3).T @ xmat
+ tarsus_pos = tarsus.pos.copy()
+ tarsus_pos[2] -= .00175
+ tarsus_quat = mat_to_quat(mat)
+ else:
+ if 'left' in leg.name:
+ rquat = np.array((1., 0, 0, 0))
+ else:
+ rquat = np.array((0., 1., 0, 0))
+ tarsus_quat = mul_quat(rquat, tarsus.quat)
+ tarsus_pos = tarsus.pos.copy()
+ tarsus_pos[2] -= .00175
+ change_body_frame(tarsus, tarsus_pos, tarsus_quat)
+
+ # set lower tarsi to match upper tarsus
+ parent = tarsus
+ children = parent.get_children('body')
+ while children:
+ child = children[0]
+ change_body_frame(child, child.pos, np.array((1., 0, 0, 0)))
+ parent = child
+ children = parent.get_children('body')
+
+ # tibia
+ change_body_frame(tibia, tibia.pos, mul_quat(tibia.quat, quat))
+
+ # femur
+ femur_to_tibia = bound_tibia.xpos - bound_femur.xpos
+ femur_to_coxa = coxa.xpos - bound_femur.xpos
+ extend = np.cross(femur_to_tibia, femur_to_coxa)
+ extend /= np.linalg.norm(extend)
+ twist = femur_to_tibia
+ twist /= np.linalg.norm(twist)
+ if 'right' in femur.name:
+ twist *= -1
+ abduct = np.cross(-extend, twist)
+ xmat = np.vstack((-extend, twist, abduct)).T
+ mat = coxa.xmat.reshape(3, 3).T @ xmat
+ change_body_frame(femur, femur.pos, mat_to_quat(mat))
+
+ # coxa
+ twist = -femur_to_coxa
+ twist /= np.linalg.norm(twist)
+ if 'right' in leg.name:
+ twist *= -1
+ abduct = np.cross(extend, twist)
+ xmat = np.vstack((extend, twist, abduct)).T
+ mat = physics.bind(thorax).xmat.reshape(3, 3).T @ xmat
+ change_body_frame(leg, leg.pos, mat_to_quat(mat))
+
+ print('Recompile.')
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ print('== Make collision geoms.')
+ print('Make leg collision geoms.')
+ for leg_meshes in legs_meshes:
+ fromto = None
+ for geom in leg_meshes:
+ pos = physics.bind(geom).pos
+ quat = physics.bind(geom).quat
+ size = physics.bind(geom).size
+ children = geom.parent.get_children('body')
+ gtype = geom.type
+ if len(children) == 1:
+ dclass = 'collision'
+ if 'coxa' in geom.name:
+ fromto = None
+ if 'T3' in geom.name:
+ axis = children[0].pos / np.linalg.norm(children[0].pos)
+ quat = quat_z2vec(axis)
+ size = np.array((0.007, 0.00875, 0.016625))
+ pos[1] -= 0.00175 * (1 if 'left' in geom.name else -1)
+ elif 'T2' in geom.name:
+ axis = children[0].pos / np.linalg.norm(children[0].pos)
+ quat = quat_z2vec(axis)
+ size = np.array((0.007875, 0.007, 0.014875))
+ pos *= 0.7875
+ else:
+ size *= 1.225
+ else:
+ quat = None
+ pos = None
+ from_ = [0, 0, 0]
+ if 'femur' in geom.name:
+ from_[2] = .004375 * (1 if 'left' in geom.name else -1)
+ from_[1] = .002625 * (1 if 'left' in geom.name else -1)
+ fromto = np.hstack((from_, children[0].pos*0.95))
+ if 'T3' in geom.name:
+ fromto[0] += 0.002625 * (-1 if 'left' in geom.name else 1)
+ fromto[3] += 0.002625 * (-1 if 'left' in geom.name else 1)
+ fromto[4] += 0.002625 * (1 if 'left' in geom.name else -1)
+ if 'T1' in geom.name:
+ fromto[0] += .002625 * (1 if 'left' in geom.name else -1)
+ size = (1.05 * size[0],)
+ elif 'tibia' in geom.name and 'T3' in geom.name:
+ fromto = np.hstack((from_, children[0].pos))
+ fromto[1] += .003 * (-1 if 'left' in geom.name else 1)
+ fromto[2] += .005 * (-1 if 'left' in geom.name else 1)
+ size = (1.05 * size[0],)
+ else:
+ fromto = np.hstack((from_, children[0].pos))
+ size = (1.225 * size[0],)
+ gtype = None
+ else:
+ dclass = 'adhesion-collision'
+ pos = None
+ quat = None
+ size = (1.225 * size[0],)
+ name = geom.name + '_collision'
+ index = geom.parent._children.index(geom) # pylint: disable=protected-access
+ geom.parent.insert('geom', index+1, fromto=fromto, size=size, quat=quat,
+ name=name, type=gtype, dclass=dclass, pos=pos)
+ geom.type = None
+
+ print('Make wing collision geoms.')
+ for geom in wing_meshes:
+ pos = physics.bind(geom).pos
+ quat = physics.bind(geom).quat
+ name = geom.name + '_collision'
+ # Adjust MuJoCo's geom fits.
+ if 'membrane' in name:
+ size = (0.0030625, 0.055125, 0.11375)
+ angle = 0.11 * (1 if 'right' in name else -1)
+ lateral = 0.02625 * (-1 if 'right' in name else 1)
+ forward = -0.002625
+ in_plane = np.array([[forward], [lateral]])
+ else:
+ size = (0.0021875, 0.0175, 0.11375)
+ angle = 0.05 * (1 if 'right' in name else -1)
+ lateral = 0.0035 * (-1 if 'right' in name else 1)
+ forward = 0.0100625
+ in_plane = np.array([[forward], [lateral]])
+ mat = quat_to_mat(physics.bind(geom).quat)
+ offset = mat[:, 1:3] @ in_plane
+ pos = pos + offset.flatten()
+ rotate = np.array((np.cos(angle/2), np.sin(angle/2), 0, 0))
+ quat = mul_quat(quat, rotate)
+
+ index = geom.parent._children.index(geom) # pylint: disable=protected-access
+ colgeom = geom.parent.insert('geom', index+1, pos=pos, quat=quat, size=size,
+ name=name, type=geom.type, dclass='collision')
+ geom.type = None
+ # set wing inertias using a custom geom
+ geom.mass = 0
+ if 'membrane' in name:
+ colgeom.dclass = 'collision-membrane'
+ # add fluid-interaction geoms
+ gsize = colgeom.size.copy()
+ gsize[0] = 0.0005 # 2 microns
+ gname = geom.parent.name + '_inertial'
+ geom.parent.insert('geom', index+2, pos=colgeom.pos, quat=colgeom.quat,
+ name=gname, dclass='wing-inertial', size=gsize)
+ gname = geom.parent.name + '_fluid'
+ geom.parent.insert('geom', index+2, pos=colgeom.pos, quat=colgeom.quat,
+ name=gname, dclass='wing-fluid', size=gsize)
+
+ print('Make mouth collision geoms.')
+ for geom in mouth_meshes:
+ pos = physics.bind(geom).pos
+ quat = physics.bind(geom).quat
+
+ # Adjust MuJoCo's geom fits
+ if 'haustellum' in geom.name:
+ size = (0.007875, 0.007875)
+ else:
+ size = (0.0035, 0.00875, 0.013125)
+ pos *= 1.22
+ # save collision geom
+ name = geom.name + '_collision'
+ index = geom.parent._children.index(geom) # pylint: disable=protected-access
+ dclass = 'adhesion-collision' if 'labrum' in geom.name else 'collision'
+ geom.parent.insert('geom', index+1, pos=pos, quat=quat, size=size,
+ name=name, type=geom.type, dclass=dclass)
+ geom.type = None
+
+ print('Make antennae collision geoms.')
+ antennae = [b for b in thorax.find_all('body') if 'antenna' in b.name]
+ for antenna in antennae:
+ pos = (0, 0.011375, 0.002625)
+ zaxis = (0, 0.875, -0.175)
+ size = (0.0048125, 0.00875)
+ name = antenna.name + '_collision'
+ quat = np.zeros(4)
+ mjlib.mju_quatZ2Vec(quat, np.asarray(zaxis))
+ antenna.insert('geom', -1, name=name, type='capsule', dclass='collision',
+ pos=pos, quat=quat, size=size)
+
+ print('Make abdomen collision geoms.')
+ abdomens = [b for b in thorax.find_all('body') if 'abdomen' in b.name]
+ for abdomen in abdomens:
+ name = abdomen.name + '_collision'
+ inertia = physics.bind(abdomen).inertia
+ mass = physics.bind(abdomen).mass
+ pos = physics.bind(abdomen).ipos
+ quat = physics.bind(abdomen).iquat
+ # Get inertia box
+ size = np.zeros(3)
+ for i in range(3):
+ not_i = {0, 1, 2} - {i}
+ accum = 0.0
+ for j in not_i:
+ accum += inertia[j]
+ accum -= inertia[i]
+ size[i] = np.sqrt(accum / mass * 6) / 2
+
+ if '7' in abdomen.name:
+ size = (0.02625,)
+ gtype = 'sphere'
+ quat = None
+ else:
+ radius = size[1:3].max()
+ height = size[0]
+ # axis = children[0].pos / np.linalg.norm(children[0].pos)
+ if '1' in abdomen.name:
+ angle = np.pi/2
+ rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
+ quat = mul_quat(quat, rotate)
+ pos[2] -= 0.00525
+ height *= 1.5
+ else:
+ # quat = quat_z2vec(np.array((0, 1., 0)))
+ angle = np.pi/2 + 0.1
+ rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
+ quat = mul_quat(quat, rotate)
+ pos[2] -= 0.0875 * size[0]
+ radius *= 1.05
+ if '3' in abdomen.name:
+ pos[2] += 0.0021875
+ if '6' in abdomen.name:
+ angle = - 0.096
+ rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
+ quat = mul_quat(quat, rotate)
+ size = np.array([radius, height, 0])
+ gtype = 'cylinder'
+ # Make collision geom.
+ abdomen.insert('geom', 4, type=gtype, dclass='collision', size=size,
+ pos=pos, quat=quat, name=name)
+
+ print('Add bespoke collision geoms.')
+ # Thorax.
+ angle = -1.
+ rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
+ thorax.insert('geom', 2, name='thorax_collision', dclass='collision',
+ type='ellipsoid', size=(0.04375, 0.04375, 0.055125),
+ pos=(-0.0175, 0, -0.002625), quat=rotate)
+ angle = -.4
+ rotate = np.array((np.cos(angle/2), 0, np.sin(angle/2), 0))
+ thorax.insert('geom', 3, name='thorax_collision2', dclass='collision',
+ type='ellipsoid', size=(0.049875, 0.028, 0.011375),
+ pos=(-0.011375, 0, 0.021875), quat=rotate)
+ # Head.
+ head = thorax.find('body', 'head')
+ head.insert('site', 0, name='head', pos=(0, 0.015, 0), quat=(0, 0, 0, 1))
+ head.insert('geom', 4, name='head_collision', dclass='collision',
+ type='ellipsoid', size=(0.0455, 0.02625, 0.032375),
+ pos=(0, 0.014875, 0.000875), euler=(.3, 0, 0))
+ # Eye cameras.
+ mat = np.zeros((3, 3))
+ mat[:, 0] = [.45, -1, -.3]
+ mat[:, 0] /= np.linalg.norm(mat[:, 0])
+ mat[:, 1] = [-.2, 0, 1]
+ mat[:, 1] /= np.linalg.norm(mat[:, 1])
+ mat[:, 1] -= mat[:, 0] * np.dot(mat[:, 0], mat[:, 1])
+ mat[:, 1] /= np.linalg.norm(mat[:, 1])
+ mat[:, 2] = np.cross(mat[:, 0], mat[:, 1])
+ head.insert('camera', 8, name='eye_right', fovy=140,
+ pos=(0.021875, 0.013125, 0), quat=mat_to_quat(mat))
+ mat[:, 0] = [.45, 1, .3]
+ mat[:, 0] /= np.linalg.norm(mat[:, 0])
+ mat[:, 1] = [.2, 0, 1]
+ mat[:, 1] /= np.linalg.norm(mat[:, 1])
+ mat[:, 1] -= mat[:, 0] * np.dot(mat[:, 0], mat[:, 1])
+ mat[:, 1] /= np.linalg.norm(mat[:, 1])
+ mat[:, 2] = np.cross(mat[:, 0], mat[:, 1])
+ head.insert('camera', 9, name='eye_left', fovy=140,
+ pos=(-0.021875, 0.013125, 0), quat=mat_to_quat(mat))
+ # Rostrum.
+ rostrum = thorax.find('body', 'rostrum')
+ rostrum.insert('geom', 2, name='rostrum_collision', dclass='collision',
+ type='ellipsoid', size=(0.013125, 0.021875, 0.013125),
+ pos=(0, 0.0175, 0.002625), euler=(.1, 0, 0))
+ for side in ['_left', '_right']:
+ fromto = np.array((-0.006125, 0.032375, 0, -0.011375, 0.0245, -0.023625))
+ if 'r' in side:
+ fromto[0] *= -1
+ fromto[3] *= -1
+ rostrum.insert('geom', 2, name='rostrum_collision'+side, dclass='collision',
+ size=(0.0035,), fromto=fromto)
+
+ print('Recompile.')
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ print('== Finalise joints and defaults.')
+
+ print('Wing joints.')
+ # Set pitch joint range.
+ pitch_range = 2. / 3. * np.pi * np.array((-1., 1.)) - _YAW_AXIS_PITCH
+ pitch_range = np.round(pitch_range*100)/100
+ model.find('default', 'pitch').joint.range = pitch_range
+ wing_quats = [np.array([np.cos(angle/2), 0, np.sin(angle/2), 0]) for
+ angle in [_YAW_AXIS_PITCH, _YAW_AXIS_PITCH + np.pi]]
+
+ # Set wing joints.
+ for i, wing in enumerate(wings):
+ wing.childclass = model.find('default', 'wing')
+ change_body_frame(wing, wing.pos, wing_quats[i])
+ for joint in wing.get_children('joint'):
+ joint.range = None
+ joint.axis = None
+ joint.pos = None
+ if 'rx' in joint.name:
+ joint.name = joint.name.replace('rx', 'yaw')
+ joint.dclass = 'yaw'
+ if 'ry' in joint.name:
+ joint.name = joint.name.replace('ry', 'roll')
+ joint.dclass = 'roll'
+ if 'rz' in joint.name:
+ joint.name = joint.name.replace('rz', 'pitch')
+ joint.dclass = 'pitch'
+
+ print('Symmetrize wing geoms children.')
+ wing_left = thorax.find('body', 'wing_left')
+ wing_right = thorax.find('body', 'wing_right')
+ for lgeom in wing_left.find_all('geom'):
+ rgeom = wing_right.find('geom', lgeom.name.replace('left', 'right'))
+ pos = (lgeom.pos - rgeom.pos)/2
+ lgeom.pos = pos
+ rgeom.pos = -pos
+
+ print('Symmetrize Antennae left->right.')
+ head_xmat = physics.bind(head).xmat.reshape(3, 3)
+ left_xmat = physics.bind(antennae[0]).xmat.reshape(3, 3)
+ right_xmat = left_xmat * np.array(([-1], [1], [-1]))
+ right_mat = head_xmat.T @ right_xmat
+ right_quat = mat_to_quat(right_mat)
+ change_body_frame(antennae[1], antennae[1].pos, right_quat)
+
+ # Axis names dict:
+ ax_names = {'rx': 'extend', 'ry': 'twist', 'rz': 'abduct'}
+
+ print('Reorder all joint axes: (ry, rz, rx).')
+ for body in model.find_all('body'):
+ joint_index = np.array((-1, -1, -1))
+ for joint in body.get_children('joint'):
+ for i, axis in enumerate(ax_names):
+ if axis in joint.name:
+ joint_index[i] = body._children.index(joint) # pylint: disable=protected-access
+ if sum(joint_index >= 0) > 1:
+ joint_reorder = [joint_index[i] for i in [2, 1, 0] if joint_index[i] >= 0]
+ joint_reorder = list(filter(lambda a: a != -1, joint_reorder))
+ joint_index = list(filter(lambda a: a != -1, joint_index))
+ child_order = list(range(len(body._children))) # pylint: disable=protected-access
+ for i, index in enumerate(joint_index):
+ child_order[index] = joint_reorder[i]
+ body._children = [body._children[i] for i in child_order] # pylint: disable=protected-access
+
+ print('Rename all joints.')
+ for joint in model.find_all('joint'):
+ for axis in ax_names:
+ joint.name = joint.name.replace(axis, ax_names[axis])
+
+ print('Head joints.')
+ head_joint_range = {'twist': (-3, 3),
+ 'abduct': (-.2, .2),
+ 'extend': (-.5, .3)}
+ head.childclass = model.find('default', 'head')
+ for joint in head.get_children('joint'):
+ for joint_range in head_joint_range:
+ if joint_range in joint.name:
+ joint.range = head_joint_range[joint_range]
+ for i, side in enumerate(sides):
+ labrum = head.find('body', 'labrum_' + side)
+ labrum.childclass = model.find('default', 'labrum')
+ pos = labrum.pos - (i*2-1) * np.array((0.002625, 0, 0))
+ pos -= np.array((0, 0.002625, 0))
+ change_body_frame(labrum, pos, labrum.quat)
+ for joint in labrum.get_children('joint'):
+ joint.pos = None
+
+ print('Abdominal joints.')
+ abdomens[0].childclass = model.find('default', 'abdomen')
+ abdomens[0].name = abdomens[0].name.replace('abdomen_1', 'abdomen')
+ for child in abdomens[0]._children: # pylint: disable=protected-access
+ child.name = child.name.replace('abdomen_1', 'abdomen')
+ def_ab_abduct = model.find('default', 'abduct_abdomen')
+ def_ab_extend = model.find('default', 'extend_abdomen')
+ for abdomen in abdomens:
+ for joint in abdomen.get_children('joint'):
+ if 'extend' in joint.name:
+ joint.dclass = def_ab_extend
+ else:
+ joint.dclass = def_ab_abduct
+ joint.axis = None
+ joint.range = None
+
+ print('Haltere joints.')
+ halteres = [b for b in thorax.find_all('body') if 'haltere' in b.name]
+ for haltere in halteres:
+ if 'left' in haltere.name:
+ rotate_y = np.array((0., 0., 1., 0.))
+ change_body_frame(haltere, haltere.pos, mul_quat(haltere.quat, rotate_y))
+ for joint in haltere.get_children('joint'):
+ joint.pos = None
+ joint.axis = None
+ joint.range = None
+ haltere.childclass = model.find('default', 'haltere')
+
+ print('Antennae joints.')
+ for antenna in antennae:
+ for joint in antenna.get_children('joint'):
+ for axis in ax_names.values():
+ if axis in joint.name:
+ joint.dclass = model.find('default', 'antenna_' + axis)
+ joint.axis = None
+ joint.range = None
+ joint.pos = None
+
+ print('Leg joints.')
+ # abduct_femur -> twist_femur
+ for leg in legs:
+ for joint in leg.find_all('joint'):
+ if 'abduct_femur' in joint.name:
+ joint.name = joint.name.replace('abduct_femur', 'twist_femur')
+
+ for leg in legs:
+ leg.childclass = model.find('default', 'leg')
+ parent = leg
+ while parent:
+ # set joint properties
+ for joint in parent.get_children('joint'):
+ joint.pos = None
+ joint.range = None
+ joint.axis = None
+ def_name = joint.name.replace('_left', '').replace('_right', '')
+ if not model.find('default', def_name):
+ def_name = def_name[:-3] # remove _TX
+ if not model.find('default', def_name):
+ def_name = def_name[:-1] # remove tarsus index
+ if not model.find('default', def_name):
+ raise ValueError('Default class not found for joint.')
+ joint.dclass = model.find('default', def_name)
+
+ # remove some unit quaternions while we're here
+ if parent.quat is not None and parent.quat[0] == 1:
+ parent.quat = None
+
+ children = parent.get_children('body')
+ if children:
+ parent = children[0]
+ else:
+ parent = None
+
+ print('Abdomen tendons.')
+ abdomen_tendons = {'abduct_abdomen': None, 'extend_abdomen': None}
+ for name in abdomen_tendons:
+ tendon = model.tendon.add('fixed', name=name)
+ abdomen_tendons[name] = tendon
+ for abdomen in abdomens:
+ for joint in abdomen.get_children('joint'):
+ if name in joint.name:
+ tendon.add('joint', joint=joint, coef=1)
+
+ print('Tarsus tendons.')
+ tarsus_tendons = {}
+ for leg in legs:
+ # Tendon couples tarsi starting from tarsus2.
+ parent = leg.find('body', leg.name.replace('coxa', 'tarsus2'))
+ name = 'extend_' + leg.name.replace('coxa', 'tarsus2')
+ tendon = model.tendon.add('fixed', name=name, dclass='extend_tarsus')
+ tarsus_tendons[name] = tendon
+ while parent.get_children('joint'):
+ joint = parent.get_children('joint')[0]
+ if 'tarsus2' in joint.name:
+ coef = 1
+ else:
+ coef = .5
+ tendon.add('joint', joint=joint, coef=coef)
+ if parent.get_children('body'):
+ parent = parent.get_children('body')[0]
+ else:
+ break
+
+ print('Add "general" actuators.')
+ for joint in model.find_all('joint'):
+ if 'free' in joint.name or 'haltere' in joint.name:
+ continue
+ if 'abdomen' in joint.name:
+ for tendon_name in abdomen_tendons:
+ if joint.name == tendon_name:
+ dclass = model.find('default', tendon_name)
+ num_joints = len(abdomen_tendons[tendon_name].get_children('joint'))
+ model.actuator.add('general', name=tendon_name,
+ tendon=abdomen_tendons[tendon_name], dclass=dclass,
+ ctrlrange=num_joints*dclass.joint.range)
+ continue
+ if 'tarsus' in joint.name:
+ if 'tarsus2' in joint.name:
+ name = joint.name
+ tendon = model.find('tendon', name)
+ if tendon is not None:
+ trange = np.array((0.0, 0.0))
+ for tjoint in tendon.get_children('joint'):
+ trange += tjoint.joint.dclass.joint.range * tjoint.coef
+ model.actuator.add('general', name=name,
+ tendon=tendon, dclass=tendon.dclass,
+ ctrlrange=trange)
+ continue
+ elif 'abduct_tarsus' not in joint.name and 'tarsus_' not in joint.name:
+ continue
+ dclass = joint.dclass
+ parent = joint.parent
+ if dclass is None:
+ while parent.childclass is None:
+ parent = parent.parent
+ dclass = parent.childclass
+ if (
+ 'twist' in joint.name
+ or 'abduct' in joint.name
+ or 'extend' in joint.name
+ ):
+ if joint.range is not None:
+ jrange = joint.range
+ elif dclass.joint.range is not None:
+ jrange = dclass.joint.range
+ else:
+ jrange = dclass.parent.joint.range
+ # if 'twist_coxa_T2' in joint.name:
+ # import ipdb; ipdb.set_trace()
+ if ('coxa' in joint.name or 'femur' in joint.name or
+ 'tibia' in joint.name or 'tarsus' in joint.name):
+ assert (dclass.joint.range is not None or
+ dclass.parent.joint.range is not None)
+ actrange = None
+ else:
+ actrange = jrange
+ model.actuator.add('general', name=joint.name, joint=joint,
+ dclass=dclass, ctrlrange=actrange)
+ else: # wing joints
+ model.actuator.add('general', name=joint.name, joint=joint,
+ dclass=dclass)
+
+ print('Add "adhesion" actuators.')
+ for body in model.find_all('body'):
+ if 'claw' in body.name:
+ model.actuator.add('adhesion', name='adhere_'+body.name,
+ body=body.name, dclass='adhesion_claw')
+ if 'labrum' in body.name:
+ model.actuator.add('adhesion', name='adhere_'+body.name,
+ body=body.name, dclass='adhesion_labrum')
+
+ print('Remove abduction for tarsi and tibia.')
+ for actuator in model.find_all('actuator'):
+ if 'abduct_tarsus' in actuator.name or 'abduct_tibia' in actuator.name:
+ actuator.remove()
+ for joint in model.find_all('joint'):
+ if 'abduct_tarsus' in joint.name or 'abduct_tibia' in joint.name:
+ joint.remove()
+ for deflt in model.find_all('default'):
+ if deflt.dclass is not None:
+ if 'abduct_tarsus' in deflt.dclass or 'abduct_tibia' in deflt.dclass:
+ deflt.remove()
+
+ print('Recompile.')
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ print('Print qpos0 contacting bodies:')
+ exclude_pairs = []
+ for con in physics.data.contact:
+ body1 = physics.model.id2name(physics.model.geom_bodyid[con.geom1], 'body')
+ body2 = physics.model.id2name(physics.model.geom_bodyid[con.geom2], 'body')
+ if 'coxa' in body1 or 'coxa' in body2:
+ continue
+ print(f' {body1} and {body2} are in contact.')
+
+ print('Exclude contacts')
+ # Abdomen-abdomen (two segments apart) and wing-abdomen.
+ for i in range(len(abdomens) - 2):
+ exclude_pairs.append((abdomens[i], abdomens[i+2]))
+ exclude_pairs.append((wings[0], abdomens[0]))
+ exclude_pairs.append((wings[0], abdomens[1]))
+ exclude_pairs.append((wings[0], abdomens[2]))
+ exclude_pairs.append((wings[1], abdomens[0]))
+ exclude_pairs.append((wings[1], abdomens[1]))
+ exclude_pairs.append((wings[1], abdomens[2]))
+ # Wing-wing.
+ exclude_pairs.append((wings[0], wings[1]))
+ # Coxa-coxa, coxa-femur, femur-femur.
+ for left_coxa in legs:
+ if 'right' in left_coxa.name:
+ continue
+ right_coxa_name = left_coxa.name.replace('left', 'right')
+ right_coxa = left_coxa.parent.find('body', right_coxa_name)
+ left_femur = left_coxa.get_children('body')[0]
+ right_femur = right_coxa.get_children('body')[0]
+ exclude_pairs.append((left_coxa, right_coxa))
+ exclude_pairs.append((left_coxa, right_femur))
+ exclude_pairs.append((left_femur, right_coxa))
+ exclude_pairs.append((left_femur, right_femur))
+ # rostrum-labrum.
+ for rostrum in head.find_all('body'):
+ if 'rostrum' in rostrum.name:
+ for labrum in head.find_all('body'):
+ if 'labrum' in labrum.name:
+ exclude_pairs.append((rostrum, labrum))
+
+ for pair in exclude_pairs:
+ model.contact.add('exclude', body1=pair[0], body2=pair[1],
+ name=pair[0].name + '_' + pair[1].name)
+
+ print('Re-center thorax CoM at origin, again.')
+ offset = physics.bind(thorax).xipos
+ offset[1] = 0.0
+ thorax.pos = np.array((0., 0., 0.))
+ change_body_frame(thorax, offset, np.array((1., 0., 0., 0.)))
+ if thorax.quat[0] == 1:
+ thorax.quat = None
+ thorax.pos = None
+
+ print('Add sensors.')
+ thorax.insert('site', 1, name='thorax')
+ angle = -_YAW_AXIS_PITCH
+ thorax.insert('site', 2, name='hover_up_dir',
+ quat=np.array([np.cos(angle/2), 0, np.sin(angle/2), 0]),
+ pos=(0.02625, 0, 0.02625))
+ model.sensor.add('accelerometer', name='accelerometer', site='thorax')
+ model.sensor.add('gyro', name='gyro', site='thorax')
+ model.sensor.add('velocimeter', name='velocimeter', site='thorax')
+ touch_sites = []
+ force_sites = []
+ for leg in legs:
+ for body in leg.find_all('body'):
+ if 'claw' in body.name:
+ for geom in body.get_children('geom'):
+ if 'collision' in geom.name:
+ site = body.add('site', name=body.name,
+ dclass='adhesion-collision',
+ fromto=geom.fromto,
+ size=geom.size*1.1)
+ touch_sites.append(site)
+ if 'tarsus_' in body.name:
+ site = body.insert('site', -1, name=body.name)
+ force_sites.append(site)
+ for site in force_sites:
+ model.sensor.add('force', name='force_' + site.name, site=site)
+ for site in touch_sites:
+ model.sensor.add('touch', name='touch_' + site.name, site=site)
+
+ print('Add thorax light and cameras.')
+ thorax.insert('light', 1, name='tracking', mode='trackcom', pos=(0, 0, 1))
+ thorax.insert('light', 1, name='left', mode='trackcom', pos=(0, 1, 1),
+ dir=(0, -1, -1), diffuse=(0.3, 0.3, 0.3))
+ thorax.insert('light', 1, name='right', mode='trackcom', pos=(0, -1, 1),
+ dir=(0, 1, -1), diffuse=(0.3, 0.3, 0.3))
+ thorax.insert(
+ 'camera',
+ 2,
+ name='track1',
+ mode='trackcom',
+ pos=(0.6, 0.6, 0.22),
+ quat=(0.31246, 0.22094, 0.5334, 0.75434))
+ thorax.insert(
+ 'camera',
+ 3,
+ name='track2',
+ mode='trackcom',
+ pos=(0, -1.1, 0.1),
+ quat=(0.70711, 0.70711, 0, 0))
+ thorax.insert(
+ 'camera',
+ 4,
+ name='track3',
+ mode='trackcom',
+ pos=(-0.9, -0.9, 0.9),
+ quat=(0.82047, 0.42471, -0.17592, -0.33985))
+ thorax.insert(
+ 'camera',
+ 5,
+ name='back',
+ mode='track',
+ pos=(-0.462, 0, 0.297),
+ xyaxes=(0, -1, 0, 0.707, 0, 0.707))
+ thorax.insert(
+ 'camera',
+ 6,
+ name='side',
+ mode='track',
+ pos=(-0.055, 0.424, -0.064),
+ xyaxes=(-1, 0, 0, 0, 0, 1))
+ thorax.insert(
+ 'camera',
+ 7,
+ name='bottom',
+ mode='track',
+ pos=(0.01, 0, -0.516),
+ xyaxes=(0, 1, 0, .991, 0, 0.136))
+
+ print('Check masses (values in milligrams)')
+ def print_mass(name, current, emp):
+ print(f' {name:10}\t{current:.4g}\t\t{emp:.4g}\t\t{emp/current:.4g}')
+ def print_masses():
+ print(' part\t\tmodeled\t\tempirical\tratio')
+ thorax_mass = 1e3*physics.named.model.body_mass['thorax']
+ print_mass('thorax', thorax_mass, _MASS['thorax'])
+ for part in ['head', 'abdomen']:
+ part_mass = 1e3*physics.named.model.body_subtreemass[part]
+ print_mass(part, part_mass, _MASS[part])
+ wing_mass = 0
+ for wing in ['wing_left', 'wing_right']:
+ wing_mass += 1e3*physics.named.model.body_subtreemass[wing]
+ print_mass('wing', wing_mass/2, _MASS['wing'])
+ leg_mass = 0
+ for leg in legs:
+ leg_mass += 1e3*physics.named.model.body_subtreemass[leg.name]
+ print_mass('leg', leg_mass/6, _MASS['leg'])
+
+ print_masses()
+ total_mass_model = 1e3*physics.named.model.body_subtreemass['thorax']
+ total_mass_emp = (_MASS['head'] + _MASS['thorax'] + _MASS['abdomen'] +
+ 6 * _MASS['leg'] + 2 * _MASS['wing'])
+ print(f' Total Mass\t{total_mass_model:.4g}\t\t{total_mass_emp:.4g}')
+
+ print('Change order: axis_body -> body_axis.')
+ elements = model.find_all('actuator') + model.find_all('joint')
+ for element in elements:
+ if 'adhere' in element.name:
+ continue
+ parts_nested = [s.split('-') for s in element.name.split('_')]
+ parts = list(itertools.chain.from_iterable(parts_nested))
+ if len(parts) < 2:
+ continue
+ order = list(range(len(parts)))
+ order[0] = 1
+ order[1] = 0
+ element.name = '_'.join([parts[i] for i in order])
+
+ print('Remove unnecessary "extend"s.')
+ for joint in model.find_all('joint'):
+ if '_extend' in joint.name:
+ joint.name = joint.name.replace('_extend', '')
+ for actuator in model.find_all('actuator'):
+ if '_extend' in actuator.name:
+ actuator.name = actuator.name.replace('_extend', '')
+ for tendon in model.find_all('tendon'):
+ if 'extend_' in tendon.name:
+ tendon.name = tendon.name.replace('extend_', '')
+
+ def shorten_names():
+ print('Shortening actuator names')
+ elements = model.find_all('actuator') + model.find_all('joint')
+ for element in elements:
+ parts_nested = [s.split('-') for s in element.name.split('_')]
+ parts = list(itertools.chain.from_iterable(parts_nested))
+ parts = [part.replace('left', 'L') for part in parts]
+ parts = [part.replace('right', 'R') for part in parts]
+ parts = [part.replace('extend', 'ex') for part in parts]
+ parts = [part.replace('abduct', 'ab') for part in parts]
+ parts = [part.replace('twist', 'tw') for part in parts]
+ parts = [part.replace('wing', 'wng') for part in parts]
+ if len(element.name) > 10:
+ parts = [part.replace('T1', '1') for part in parts]
+ parts = [part.replace('T2', '2') for part in parts]
+ parts = [part.replace('T3', '3') for part in parts]
+ parts = [part.replace('tarsus', 'tars') for part in parts]
+ parts = [part.replace('pitch', 'ptch') for part in parts]
+ parts = [part.replace('femur', 'fem') for part in parts]
+ parts = [part.replace('antenna', 'anten') for part in parts]
+ # parts = [part.replace('haltere', 'halt') for part in parts]
+ parts = [part.replace('haustellum', 'haust') for part in parts]
+ # parts = [part.replace('labrum', 'labr') for part in parts]
+ short_name = '_'.join(parts)
+ short_name = short_name[0].lower() + short_name[1:]
+ if short_name[-3].isdigit():
+ short_name = short_name[:-3] + short_name[-1] + short_name[-3]
+ element.name = short_name
+
+ # name-shortening: currently unused
+ if 'short' in FINAL_MODEL:
+ shorten_names()
+
+ print('== XML cleanup.')
+ print('Remove class="/" using lxml.')
+ xml_string = model.to_xml_string('float', precision=3, zero_threshold=1e-7)
+ root = etree.XML(xml_string,
+ etree.XMLParser(remove_blank_text=True))
+ default_elem = root.find('default')
+ root.insert(3, default_elem[0])
+ root.remove(default_elem)
+
+ print('Remove hashes from filenames.')
+ meshes = [mesh for mesh in root.find('asset').iter() if mesh.tag == 'mesh']
+ for mesh in meshes:
+ name, extension = mesh.get('file').split('.')
+ mesh.set('file', '.'.join((name[:-41], extension)))
+
+ print('Get string from lxml.')
+ xml_string = etree.tostring(root, pretty_print=True)
+ xml_string = xml_string.replace(b' class="/"', b'')
+
+ print('Remove gravcomp="0".')
+ xml_string = xml_string.replace(b' gravcomp="0"', b'')
+
+ print('Insert spaces between top level elements.')
+ lines = xml_string.splitlines()
+ newlines = []
+ for line in lines:
+ newlines.append(line)
+ if line.startswith(b' <'):
+ if line.startswith(b' ') or line.endswith(b'/>'):
+ newlines.append(b'')
+ newlines.append(b'')
+ xml_string = b'\n'.join(newlines)
+
+ print(f'Save {FINAL_MODEL} to file.')
+ with open(ASSET_RELPATH + '/../' + FINAL_MODEL, 'wb') as f:
+ f.write(xml_string)
+
+ print('Done.')
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/add_torque_actuators.py b/dm_control/locomotion/walkers/assets/dog_v2/add_torque_actuators.py
new file mode 100644
index 00000000..71c66fb4
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/add_torque_actuators.py
@@ -0,0 +1,85 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make torque actuators for the dog model."""
+
+import collections
+
+
+def add_motors(physics, model, lumbar_joints, cervical_joints, caudal_joints):
+ """Add torque motors to the model.
+
+ Args:
+ physics: a physics instance compiled from the most recent model.
+ model: model in which we want to add motors.
+ lumbar_joints: a list of joint objects.
+ cervical_joints: a list of joint objects.
+ caudal_joints: a list of joint objects.
+
+ Returns:
+ A list of actuated joints.
+ """
+ # Fixed Tendons:
+ spinal_joints = collections.OrderedDict()
+ spinal_joints['lumbar_'] = lumbar_joints
+ spinal_joints['cervical_'] = cervical_joints
+ spinal_joints['caudal_'] = caudal_joints
+ tendons = []
+ for region in spinal_joints.keys():
+ for direction in ['extend', 'bend', 'twist']:
+ joints = [
+ joint for joint in spinal_joints[region] if direction in joint.name
+ ]
+ if joints:
+ tendon = model.tendon.add(
+ 'fixed', name=region + direction, dclass=joints[0].dclass
+ )
+ tendons.append(tendon)
+ joint_inertia = physics.bind(joints).M0
+ coefs = joint_inertia**0.25
+ coefs /= coefs.sum()
+ coefs *= len(joints)
+ for i, joint in enumerate(joints):
+ tendon.add('joint', joint=joint, coef=coefs[i])
+
+ # Actuators:
+ all_spinal_joints = []
+ for region in spinal_joints.values():
+ all_spinal_joints.extend(region)
+ root_joint = model.find('joint', 'root')
+ actuated_joints = [
+ joint
+ for joint in model.find_all('joint')
+ if joint not in all_spinal_joints and joint is not root_joint
+ ]
+ for tendon in tendons:
+ gain = 0.0
+ for joint in tendon.joint:
+ # joint.joint.user = physics.bind(joint.joint).damping
+ def_joint = model.default.find('default', joint.joint.dclass)
+ j_gain = def_joint.general.gainprm or def_joint.parent.general.gainprm
+ gain += j_gain[0] * joint.coef
+ gain /= len(tendon.joint)
+
+ model.actuator.add(
+ 'general', tendon=tendon, name=tendon.name, dclass=tendon.dclass
+ )
+
+ for joint in actuated_joints:
+ model.actuator.add(
+ 'general', joint=joint, name=joint.name, dclass=joint.dclass
+ )
+
+ return actuated_joints
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_back_legs.py b/dm_control/locomotion/walkers/assets/dog_v2/build_back_legs.py
new file mode 100644
index 00000000..bbc36e2c
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_back_legs.py
@@ -0,0 +1,281 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make back legs for the dog model."""
+
+from dm_control import mjcf
+import numpy as np
+
+
+def create_back_legs(
+ model,
+ primary_axis,
+ bone_position,
+ bones,
+ side_sign,
+ bone_size,
+ pelvic_bones,
+ parent,
+):
+ """Add back legs to the model.
+
+ Args:
+ model: model in which we want to add the back legs.
+ primary_axis: a dictionary of numpy arrays representing axes of rotation.
+ bone_position: a dictionary of bone positions.
+ bones: a list of strings with all the names of the bones.
+ side_sign: a dictionary mapping each side to the signs of its
+ translations.
+ bone_size: dictionary containing the scale of the geometry.
+ pelvic_bones: a list of strings with the names of the pelvic bones.
+ parent: parent object on which we should start attaching new components.
+
+ Returns:
+ The tuple `(nails, sole_sites)`.
+ """
+ pelvis = parent
+ # Hip joint sites:
+ scale = np.asarray([bone_size[bone] for bone in pelvic_bones]).mean()
+ hip_pos = np.array((-0.23, -0.6, -0.16)) * scale
+ for side in ['_L', '_R']:
+ pelvis.add(
+ 'site', name='hip' + side, size=[0.011], pos=hip_pos * side_sign[side]
+ )
+
+ # Upper legs:
+ upper_leg = {}
+ femurs = [b for b in bones if 'Fem' in b]
+ use_tendons = False
+ if not use_tendons:
+ femurs += [b for b in bones if 'Patella' in b]
+ for side in ['_L', '_R']:
+ body_pos = hip_pos * side_sign[side]
+ leg = pelvis.add('body', name='upper_leg' + side, pos=body_pos)
+ upper_leg[side] = leg
+ for bone in [b for b in femurs if side in b]:
+ leg.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Pelvis'] - body_pos,
+ dclass='bone',
+ )
+
+ # Hip joints
+ for dof in ['_supinate', '_abduct', '_extend']:
+ axis = primary_axis[dof].copy()
+ if dof != '_extend':
+ axis *= 1.0 if side != '_R' else -1.0
+ leg.add('joint', name='hip' + side + dof, dclass='hip' + dof, axis=axis)
+
+ # Knee sites
+ scale = bone_size['Femoris_L']
+ knee_pos = np.array((-0.2, -0.27, -1.45)) * scale
+ leg.add(
+ 'site',
+ type='cylinder',
+ name='knee' + side,
+ size=[0.003, 0.02],
+ zaxis=(0, 1, 0),
+ pos=knee_pos * side_sign[side],
+ )
+ pos = np.array((-0.01, -0.02, -0.08)) * side_sign[side]
+ euler = [-10 * (1.0 if side == '_R' else -1.0), 20, 0]
+ leg.add(
+ 'geom',
+ name=leg.name + '0_collision',
+ pos=pos,
+ size=[0.04, 0.08],
+ euler=euler,
+ dclass='collision_primitive',
+ )
+ pos = np.array((-0.03, 0, -0.05))
+ euler = [-10 * (1.0 if side == '_R' else -1.0), 5, 0]
+ leg.add(
+ 'geom',
+ name=leg.name + '1_collision',
+ pos=pos,
+ size=[0.04, 0.04],
+ euler=euler,
+ dclass='collision_primitive',
+ )
+
+ # Patella
+ if use_tendons:
+ # Make patella body
+ pass
+
+ # Lower legs:
+ lower_leg = {}
+ lower_leg_bones = [b for b in bones if 'Tibia_' in b or 'Fibula' in b]
+ for side in ['_L', '_R']:
+ body_pos = knee_pos * side_sign[side]
+ leg = upper_leg[side].add('body', name='lower_leg' + side, pos=body_pos)
+ lower_leg[side] = leg
+ for bone in [b for b in lower_leg_bones if side in b]:
+ leg.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Pelvis'] - upper_leg[side].pos - body_pos,
+ dclass='bone',
+ )
+ # Knee joints
+ leg.add('joint', name='knee' + side, dclass='knee', axis=(0, -1, 0))
+
+ # Ankle sites
+ scale = bone_size['Tibia_L']
+ ankle_pos = np.array((-1.27, 0.04, -0.98)) * scale
+ leg.add(
+ 'site',
+ type='cylinder',
+ name='ankle' + side,
+ size=[0.003, 0.013],
+ zaxis=(0, 1, 0),
+ pos=ankle_pos * side_sign[side],
+ )
+
+ # Feet:
+ foot = {}
+ foot_bones = [b for b in bones if 'tars' in b.lower() or 'tuber' in b]
+ for side in ['_L', '_R']:
+ body_pos = ankle_pos * side_sign[side]
+ leg = lower_leg[side].add('body', name='foot' + side, pos=body_pos)
+ foot[side] = leg
+ for bone in [b for b in foot_bones if side in b]:
+ leg.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Pelvis']
+ - upper_leg[side].pos
+ - lower_leg[side].pos
+ - body_pos,
+ dclass='bone',
+ )
+ # Ankle joints
+ leg.add('joint', name='ankle' + side, dclass='ankle', axis=(0, 1, 0))
+ pos = np.array((-0.01, -0.005, -0.05)) * side_sign[side]
+ leg.add(
+ 'geom',
+ name=leg.name + '_collision',
+ size=[0.015, 0.07],
+ pos=pos,
+ dclass='collision_primitive',
+ )
+
+ # Toe sites
+ scale = bone_size['Metatarsi_R_2']
+ toe_pos = np.array((-0.37, -0.2, -2.95)) * scale
+ leg.add(
+ 'site',
+ type='cylinder',
+ name='toe' + side,
+ size=[0.003, 0.025],
+ zaxis=(0, 1, 0),
+ pos=toe_pos * side_sign[side],
+ )
+
+ # Toes:
+ toe_bones = [b for b in bones if 'Phalange' in b]
+ toe_geoms = []
+ sole_sites = []
+ nails = []
+ for side in ['_L', '_R']:
+ body_pos = toe_pos * side_sign[side]
+ foot_anchor = foot[side].add(
+ 'body', name='foot_anchor' + side, pos=body_pos
+ )
+ foot_anchor.add(
+ 'geom',
+ name=foot_anchor.name,
+ dclass='foot_primitive',
+ type='box',
+ size=(0.005, 0.005, 0.005),
+ contype=0,
+ conaffinity=0,
+ )
+ foot_anchor.add('site', name=foot_anchor.name, dclass='sensor')
+ leg = foot_anchor.add('body', name='toe' + side)
+ for bone in [b for b in toe_bones if side in b]:
+ geom = leg.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Pelvis']
+ - upper_leg[side].pos
+ - lower_leg[side].pos
+ - foot[side].pos
+ - body_pos,
+ dclass='bone',
+ )
+ if 'B_3' in bone:
+ nails.append(bone)
+ geom.dclass = 'visible_bone'
+ else:
+ toe_geoms.append(geom)
+ # Toe joints
+ leg.add('joint', name='toe' + side, dclass='toe', axis=(0, 1, 0))
+ # Collision geoms
+ leg.add(
+ 'geom',
+ name=leg.name + '0_collision',
+ size=[0.018, 0.012],
+ pos=[0.015, 0, -0.02],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+ leg.add(
+ 'geom',
+ name=leg.name + '1_collision',
+ size=[0.01, 0.015],
+ pos=[0.035, 0, -0.028],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+ leg.add(
+ 'geom',
+ name=leg.name + '2_collision',
+ size=[0.008, 0.01],
+ pos=[0.045, 0, -0.03],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+ sole = leg.add(
+ 'site',
+ name='sole' + side,
+ size=(0.025, 0.03, 0.008),
+ pos=(0.026, 0, -0.033),
+ type='box',
+ dclass='sensor',
+ )
+
+ sole_sites.append(sole)
+
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ for side in ['_L', '_R']:
+ # lower leg:
+ leg = lower_leg[side]
+ leg.add(
+ 'geom',
+ name=leg.name + '_collision',
+ pos=physics.bind(leg).ipos * 1.3,
+ size=[0.02, 0.1],
+ quat=physics.bind(leg).iquat,
+ dclass='collision_primitive',
+ )
+
+ return nails, sole_sites
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_dog.py b/dm_control/locomotion/walkers/assets/dog_v2/build_dog.py
new file mode 100644
index 00000000..ffcaea18
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_dog.py
@@ -0,0 +1,426 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make dog model."""
+
+import os
+
+from absl import app
+from absl import flags
+from dm_control import mjcf
+from dm_control.locomotion.walkers.assets.dog_v2 import add_torque_actuators
+from dm_control.locomotion.walkers.assets.dog_v2 import build_back_legs
+from dm_control.locomotion.walkers.assets.dog_v2 import build_front_legs
+from dm_control.locomotion.walkers.assets.dog_v2 import build_neck
+from dm_control.locomotion.walkers.assets.dog_v2 import build_tail
+from dm_control.locomotion.walkers.assets.dog_v2 import build_torso
+from dm_control.locomotion.walkers.assets.dog_v2 import create_skin
+from lxml import etree
+import numpy as np
+
+from dm_control.utils import io as resources
+
+flags.DEFINE_boolean('make_skin', True, 'Whether to make a new dog_skin.skn')
+flags.DEFINE_float(
+ 'lumbar_dofs_per_vertebra',
+ 1.5,
+ 'Number of degrees of freedom per vertebra in lumbar spine.',
+)
+flags.DEFINE_float(
+ 'cervical_dofs_per_vertebra',
+ 1.5,
+ 'Number of degrees of freedom per vertebra in cervical spine.',
+)
+flags.DEFINE_float(
+ 'caudal_dofs_per_vertebra',
+ 1,
+ 'Number of degrees of freedom per vertebra in caudal spine.',
+)
+
+FLAGS = flags.FLAGS
+
+BASE_MODEL = 'dog_base.xml'
+ASSET_RELPATH = '../../../../suite/dog_assets'
+ASSET_DIR = os.path.dirname(__file__) + '/' + ASSET_RELPATH
+print(ASSET_DIR)
+
+
+def exclude_contacts(model):
+ """Exclude contacts from model.
+
+ Args:
+ model: model in which we want to exclude contacts.
+ """
+ physics = mjcf.Physics.from_mjcf_model(model)
+ excluded_pairs = []
+ for c in physics.data.contact:
+ body1 = physics.model.id2name(physics.model.geom_bodyid[c.geom1], 'body')
+ body2 = physics.model.id2name(physics.model.geom_bodyid[c.geom2], 'body')
+ pair = body1 + ':' + body2
+ if pair not in excluded_pairs:
+ excluded_pairs.append(pair)
+ model.contact.add('exclude', name=pair, body1=body1, body2=body2)
+ # manual exclusions
+ model.contact.add(
+ 'exclude',
+ name='C_1:jaw',
+ body1=model.find('body', 'C_1'),
+ body2=model.find('body', 'jaw'),
+ )
+ model.contact.add(
+ 'exclude',
+ name='torso:lower_arm_L',
+ body1=model.find('body', 'torso'),
+ body2='lower_arm_L',
+ )
+ model.contact.add(
+ 'exclude',
+ name='torso:lower_arm_R',
+ body1=model.find('body', 'torso'),
+ body2='lower_arm_R',
+ )
+ model.contact.add(
+ 'exclude', name='C_4:scapula_R', body1='C_4', body2='scapula_R'
+ )
+ model.contact.add(
+ 'exclude', name='C_4:scapula_L', body1='C_4', body2='scapula_L'
+ )
+ model.contact.add(
+ 'exclude', name='C_5:upper_arm_R', body1='C_5', body2='upper_arm_R'
+ )
+ model.contact.add(
+ 'exclude', name='C_5:upper_arm_L', body1='C_5', body2='upper_arm_L'
+ )
+ model.contact.add(
+ 'exclude', name='C_6:upper_arm_R', body1='C_6', body2='upper_arm_R'
+ )
+ model.contact.add(
+ 'exclude', name='C_6:upper_arm_L', body1='C_6', body2='upper_arm_L'
+ )
+ model.contact.add(
+ 'exclude', name='C_7:upper_arm_R', body1='C_7', body2='upper_arm_R'
+ )
+ model.contact.add(
+ 'exclude', name='C_7:upper_arm_L', body1='C_7', body2='upper_arm_L'
+ )
+ model.contact.add(
+ 'exclude',
+ name='upper_leg_L:upper_leg_R',
+ body1='upper_leg_L',
+ body2='upper_leg_R',
+ )
+ for side in ['_L', '_R']:
+ model.contact.add(
+ 'exclude',
+ name='lower_leg' + side + ':pelvis',
+ body1='lower_leg' + side,
+ body2='pelvis',
+ )
+ model.contact.add(
+ 'exclude',
+ name='upper_leg' + side + ':foot' + side,
+ body1='upper_leg' + side,
+ body2='foot' + side,
+ )
+
+
+def main(argv):
+ del argv
+
+ # Read flags.
+ if FLAGS.is_parsed():
+ lumbar_dofs_per_vert = FLAGS.lumbar_dofs_per_vertebra
+ cervical_dofs_per_vertebra = FLAGS.cervical_dofs_per_vertebra
+ caudal_dofs_per_vertebra = FLAGS.caudal_dofs_per_vertebra
+ make_skin = FLAGS.make_skin
+ else:
+ lumbar_dofs_per_vert = FLAGS['lumbar_dofs_per_vertebra'].default
+ cervical_dofs_per_vertebra = FLAGS['cervical_dofs_per_vertebra'].default
+ caudal_dofs_per_vertebra = FLAGS['caudal_dofs_per_vertebra'].default
+ make_skin = FLAGS['make_skin'].default
+
+ print('Load base model.')
+ with open(BASE_MODEL, 'r') as f:
+ model = mjcf.from_file(f)
+
+ # Helper constants:
+ side_sign = {
+ '_L': np.array((1.0, -1.0, 1.0)),
+ '_R': np.array((1.0, 1.0, 1.0)),
+ }
+ primary_axis = {
+ '_abduct': np.array((-1.0, 0.0, 0.0)),
+ '_extend': np.array((0.0, 1.0, 0.0)),
+ '_supinate': np.array((0.0, 0.0, -1.0)),
+ }
+
+ # Add meshes:
+ print('Loading all meshes, getting positions and sizes.')
+ meshdir = ASSET_DIR
+ model.compiler.meshdir = meshdir
+ texturedir = ASSET_DIR
+ model.compiler.texturedir = texturedir
+ bones = []
+ for dirpath, _, filenames in resources.WalkResources(meshdir):
+ prefix = 'extras/' if 'extras' in dirpath else ''
+ for filename in filenames:
+ if 'dog_skin.msh' in filename:
+ skin_msh = model.asset.add(
+ 'mesh', name='skin_msh', file=filename, scale=(1.25, 1.25, 1.25)
+ )
+ name = filename[4:-4]
+ name = name.replace('*', ':')
+ if filename.startswith('BONE'):
+ if 'Lingual' not in name:
+ bones.append(name)
+ model.asset.add('mesh', name=name, file=prefix + filename)
+
+ # Put all bones in worldbody, get positions, remove bones:
+ bone_geoms = []
+ for bone in bones:
+ geom = model.worldbody.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ type='mesh',
+ contype=0,
+ conaffinity=0,
+ rgba=[1, 0.5, 0.5, 0.4],
+ )
+ bone_geoms.append(geom)
+ physics = mjcf.Physics.from_mjcf_model(model)
+ bone_position = {}
+ bone_size = {}
+ for bone in bones:
+ geom = model.find('geom', bone)
+ bone_position[bone] = np.array(physics.bind(geom).xpos)
+ bone_size[bone] = np.array(physics.bind(geom).rbound)
+ geom.remove()
+
+ # Torso
+ print('Torso, lumbar spine, pelvis.')
+ pelvic_bones, lumbar_joints = build_torso.create_torso(
+ model,
+ bones,
+ bone_position,
+ lumbar_dofs_per_vert,
+ side_sign,
+ parent=model.worldbody,
+ )
+
+ print('Neck, skull, jaw.')
+ # Cervical spine (neck) bodies:
+ cervical_joints = build_neck.create_neck(
+ model,
+ bone_position,
+ cervical_dofs_per_vertebra,
+ bones,
+ side_sign,
+ bone_size,
+ parent=model.find('body', 'torso'),
+ )
+
+ print('Back legs.')
+ nails, sole_sites = build_back_legs.create_back_legs(
+ model,
+ primary_axis,
+ bone_position,
+ bones,
+ side_sign,
+ bone_size,
+ pelvic_bones,
+ parent=model.find('body', 'pelvis'),
+ )
+
+ print('Shoulders, front legs.')
+ palm_sites = build_front_legs.create_front_legs(
+ nails,
+ model,
+ primary_axis,
+ bones,
+ side_sign,
+ parent=model.find('body', 'torso'),
+ )
+
+ print('Tail.')
+ caudal_joints = build_tail.create_tail(
+ caudal_dofs_per_vertebra,
+ bone_size,
+ model,
+ bone_position,
+ parent=model.find('body', 'pelvis'),
+ )
+
+ print('Collision geoms, fixed tendons.')
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ print('Unify ribcage and jaw meshes.')
+ for body in model.find_all('body'):
+ body_meshes = [
+ geom
+ for geom in body.all_children()
+ if geom.tag == 'geom'
+ and hasattr(geom, 'mesh')
+ and geom.mesh is not None
+ ]
+ if len(body_meshes) > 10:
+ mergables = [
+ ('torso', 'Ribcage'),
+ ('jaw', 'Jaw'),
+ ('skull', 'MergedSkull'),
+ ]
+ for bodyname, meshname in mergables:
+ if body.name == bodyname:
+ print('==== Merging ', bodyname)
+ for mesh in body_meshes:
+ print(mesh.name)
+ body.add(
+ 'inertial',
+ mass=physics.bind(body).mass,
+ pos=physics.bind(body).ipos,
+ quat=physics.bind(body).iquat,
+ diaginertia=physics.bind(body).inertia,
+ )
+
+ for mesh in body_meshes:
+ if 'eye' not in mesh.name:
+ model.find('mesh', mesh.name).remove()
+ mesh.remove()
+ body.add(
+ 'geom',
+ name=meshname,
+ mesh=meshname,
+ dclass='bone',
+ pos=-bone_position[meshname],
+ )
+
+ print('Add Actuators')
+ actuated_joints = add_torque_actuators.add_motors(
+ physics, model, lumbar_joints, cervical_joints, caudal_joints
+ )
+
+ print('Excluding contacts.')
+ exclude_contacts(model)
+
+ if make_skin:
+ create_skin.create(model, skin_msh)
+
+ # Add skin from .skn
+ print('Adding Skin.')
+ skin_texture = model.asset.add(
+ 'texture', name='skin', file='skin_texture.png', type='2d'
+ )
+ skin_material = model.asset.add('material', name='skin', texture=skin_texture)
+ model.asset.add(
+ 'skin', name='skin', file='dog_skin.skn', material=skin_material
+ )
+ skin_msh.remove()
+
+ print('Removing non-essential sites.')
+ all_sites = model.find_all('site')
+ for site in all_sites:
+ if site.dclass is None:
+ site.remove()
+
+ physics = mjcf.Physics.from_mjcf_model(model)
+ # sensors
+ model.sensor.add(
+ 'accelerometer', name='accelerometer', site=model.find('site', 'head')
+ )
+ model.sensor.add(
+ 'velocimeter', name='velocimeter', site=model.find('site', 'head')
+ )
+ model.sensor.add('gyro', name='gyro', site=model.find('site', 'head'))
+ model.sensor.add(
+ 'subtreelinvel', name='torso_linvel', body=model.find('body', 'torso')
+ )
+ model.sensor.add(
+ 'subtreeangmom', name='torso_angmom', body=model.find('body', 'torso')
+ )
+ for site in palm_sites + sole_sites:
+ model.sensor.add('touch', name=site.name, site=site)
+ anchors = [site for site in model.find_all('site') if 'anchor' in site.name]
+ for site in anchors:
+ model.sensor.add('force', name=site.name.replace('_anchor', ''), site=site)
+
+ # Print stuff
+ joint_acts = [model.find('actuator', j.name) for j in actuated_joints]
+ print(
+ '{:20} {:>10} {:>10} {:>10} {:>10} {:>10}'.format(
+ 'name', 'mass', 'damping', 'stiffness', 'ratio', 'armature'
+ )
+ )
+ for i, j in enumerate(actuated_joints):
+ dmp = physics.bind(j).damping[0]
+ mass_eff = physics.bind(j).M0[0]
+ dmp = physics.bind(j).damping[0]
+ stf = physics.bind(joint_acts[i]).gainprm[0]
+ arma = physics.bind(j).armature[0]
+ print(
+ '{:20} {:10.4} {:10} {:10.4} {:10.4} {:10}'.format(
+ j.name,
+ mass_eff,
+ dmp,
+ stf,
+ dmp / (2 * np.sqrt(mass_eff * stf)),
+ arma,
+ )
+ )
+
+ print('Finalising and saving model.')
+ xml_string = model.to_xml_string('float', precision=4, zero_threshold=1e-7)
+ root = etree.XML(xml_string, etree.XMLParser(remove_blank_text=True))
+
+ print('Remove hashes from filenames')
+ assets = list(root.find('asset').iter())
+ for asset in assets:
+ asset_filename = asset.get('file')
+ if asset_filename is not None:
+ name = asset_filename[:-4]
+ extension = asset_filename[-4:]
+ asset.set('file', name[:-41] + extension)
+
+ print('Add compiler element, for locally-loadable model')
+ compiler = etree.Element(
+ 'compiler', meshdir=ASSET_RELPATH, texturedir=ASSET_RELPATH
+ )
+ root.insert(0, compiler)
+
+ print('Remove class="/"')
+ default_elem = root.find('default')
+ root.insert(6, default_elem[0])
+ root.remove(default_elem)
+ xml_string = etree.tostring(root, pretty_print=True)
+ xml_string = xml_string.replace(b' class="/"', b'')
+
+ print('Insert spaces between top level elements')
+ lines = xml_string.splitlines()
+ newlines = []
+ for line in lines:
+ newlines.append(line)
+ if line.startswith(b' <'):
+ if line.startswith(b' ') or line.endswith(b'/>'):
+ newlines.append(b'')
+ newlines.append(b'')
+ xml_string = b'\n'.join(newlines)
+
+ # Save to file.
+ f = open('dog.xml', 'wb')
+ f.write(xml_string)
+ f.close()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_front_legs.py b/dm_control/locomotion/walkers/assets/dog_v2/build_front_legs.py
new file mode 100644
index 00000000..e059da07
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_front_legs.py
@@ -0,0 +1,336 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make front legs for the dog model."""
+
+from dm_control import mjcf
+import numpy as np
+
+
+def create_front_legs(nails, model, primary_axis, bones, side_sign, parent):
+ """Add front legs to the model.
+
+ Args:
+ nails: a list of strings naming the geoms that represent nails.
+ model: model in which we want to add the front legs.
+ primary_axis: a dictionary of numpy arrays representing axes of rotation.
+ bones: a list of strings with all the names of the bones.
+ side_sign: a dictionary mapping each side to the signs of its
+ translations.
+ parent: parent object on which we should start attaching new components.
+
+ Returns:
+ A list of palm sites.
+ """
+ def_scapula_supinate = model.default.find('default', 'scapula_supinate')
+ def_scapula_abduct = model.default.find('default', 'scapula_abduct')
+ def_scapula_extend = model.default.find('default', 'scapula_extend')
+
+ scapula_defaults = {
+ '_abduct': def_scapula_abduct,
+ '_extend': def_scapula_extend,
+ '_supinate': def_scapula_supinate,
+ }
+
+ torso = parent
+ # Shoulders
+ scapula = {}
+ scapulae = [b for b in bones if 'Scapula' in b]
+ scapula_pos = np.array((0.08, -0.02, 0.14))
+ for side in ['_L', '_R']:
+ body_pos = scapula_pos * side_sign[side]
+ arm = torso.add('body', name='scapula' + side, pos=body_pos)
+ scapula[side] = arm
+ for bone in [b for b in scapulae if side in b]:
+ arm.add(
+ 'geom', name=bone, mesh=bone, pos=-torso.pos - body_pos, dclass='bone'
+ )
+
+ # Shoulder joints
+ for dof in ['_supinate', '_abduct', '_extend']:
+ joint_axis = primary_axis[dof].copy()
+ joint_pos = scapula_defaults[dof].joint.pos.copy()
+ if dof != '_extend':
+ joint_axis *= 1.0 if side == '_R' else -1.0
+ joint_pos *= side_sign[side]
+ else:
+ joint_axis += (
+ 0.3 * (1 if side == '_R' else -1) * primary_axis['_abduct']
+ )
+ arm.add(
+ 'joint',
+ name='scapula' + side + dof,
+ dclass='scapula' + dof,
+ axis=joint_axis,
+ pos=joint_pos,
+ )
+
+ # Shoulder sites
+ shoulder_pos = np.array((0.075, -0.033, -0.13))
+ arm.add(
+ 'site',
+ name='shoulder' + side,
+ size=[0.01],
+ pos=shoulder_pos * side_sign[side],
+ )
+
+ # Upper Arms:
+ upper_arm = {}
+ parent_pos = {}
+ humeri = ['humerus_R', 'humerus_L']
+ for side in ['_L', '_R']:
+ body_pos = shoulder_pos * side_sign[side]
+ parent = scapula[side]
+ parent_pos[side] = torso.pos + parent.pos
+ arm = parent.add('body', name='upper_arm' + side, pos=body_pos)
+ upper_arm[side] = arm
+ for bone in [b for b in humeri if side in b]:
+ arm.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-parent_pos[side] - body_pos,
+ dclass='bone',
+ )
+ parent_pos[side] += body_pos
+
+ # Shoulder joints
+ for dof in ['_supinate', '_extend']:
+ joint_axis = primary_axis[dof].copy()
+ if dof == '_supinate':
+ joint_axis[0] = 1
+ joint_axis *= 1.0 if side == '_R' else -1.0
+ arm.add(
+ 'joint',
+ name='shoulder' + side + dof,
+ dclass='shoulder' + dof,
+ axis=joint_axis,
+ )
+
+ # Elbow sites
+ elbow_pos = np.array((-0.05, -0.015, -0.145))
+ arm.add(
+ 'site',
+ type='cylinder',
+ name='elbow' + side,
+ size=[0.003, 0.02],
+ zaxis=(0, 1, -(1.0 if side == '_R' else -1.0) * 0.2),
+ pos=elbow_pos * side_sign[side],
+ )
+
+ # Lower arms:
+ lower_arm = {}
+ lower_arm_bones = [
+ b
+ for b in bones
+ if 'ulna' in b.lower() or 'Radius' in b or 'accessory' in b
+ ]
+ for side in ['_L', '_R']:
+ body_pos = elbow_pos * side_sign[side]
+ arm = upper_arm[side].add('body', name='lower_arm' + side, pos=body_pos)
+ lower_arm[side] = arm
+ for bone in [b for b in lower_arm_bones if side in b]:
+ arm.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-parent_pos[side] - body_pos,
+ dclass='bone',
+ )
+ # Elbow joints
+ elbow_axis = upper_arm[side].find_all('site')[0].zaxis
+ arm.add('joint', name='elbow' + side, dclass='elbow', axis=elbow_axis)
+ parent_pos[side] += body_pos
+
+ # Wrist sites
+ wrist_pos = np.array((0.003, 0.015, -0.19))
+ arm.add(
+ 'site',
+ type='cylinder',
+ name='wrist' + side,
+ size=[0.004, 0.017],
+ zaxis=(0, 1, 0),
+ pos=wrist_pos * side_sign[side],
+ )
+
+ # Hands:
+ hands = {}
+ hand_bones = [
+ b
+ for b in bones
+ if ('carpal' in b.lower() and 'acces' not in b and 'ulna' not in b)
+ or ('distalis_digiti_I_' in b)
+ ]
+ for side in ['_L', '_R']:
+ body_pos = wrist_pos * side_sign[side]
+ hand_anchor = lower_arm[side].add(
+ 'body', name='hand_anchor' + side, pos=body_pos
+ )
+ hand_anchor.add(
+ 'geom',
+ name=hand_anchor.name,
+ dclass='foot_primitive',
+ type='box',
+ size=(0.005, 0.005, 0.005),
+ contype=0,
+ conaffinity=0,
+ )
+ hand_anchor.add('site', name=hand_anchor.name, dclass='sensor')
+ hand = hand_anchor.add('body', name='hand' + side)
+ hands[side] = hand
+ for bone in [b for b in hand_bones if side in b]:
+ hand.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-parent_pos[side] - body_pos,
+ dclass='bone',
+ )
+ # Wrist joints
+ hand.add('joint', name='wrist' + side, dclass='wrist', axis=(0, -1, 0))
+ hand.add(
+ 'geom',
+ name=hand.name + '_collision',
+ size=[0.03, 0.016, 0.012],
+ pos=[0.01, 0, -0.04],
+ euler=(0, 65, 0),
+ dclass='collision_primitive',
+ type='box',
+ )
+
+ parent_pos[side] += body_pos
+
+ # Finger sites
+ finger_pos = np.array((0.02, 0, -0.06))
+ hand.add(
+ 'site',
+ type='cylinder',
+ name='finger' + side,
+ size=[0.003, 0.025],
+ zaxis=((1.0 if side == '_R' else -1.0) * 0.2, 1, 0),
+ pos=finger_pos * side_sign[side],
+ )
+
+ # Fingers:
+ finger_bones = [
+ b for b in bones if 'Phalanx' in b and 'distalis_digiti_I_' not in b
+ ]
+ palm_sites = []
+ for side in ['_L', '_R']:
+ body_pos = finger_pos * side_sign[side]
+ hand = hands[side].add('body', name='finger' + side, pos=body_pos)
+ for bone in [b for b in finger_bones if side in b]:
+ geom = hand.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-parent_pos[side] - body_pos,
+ dclass='bone',
+ )
+ if 'distalis' in bone:
+ nails.append(bone)
+ geom.dclass = 'visible_bone'
+ # Finger joints
+ finger_axis = upper_arm[side].find_all('site')[0].zaxis
+ hand.add('joint', name='finger' + side, dclass='finger', axis=finger_axis)
+ hand.add(
+ 'geom',
+ name=hand.name + '0_collision',
+ size=[0.018, 0.012],
+ pos=[0.012, 0, -0.012],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+ hand.add(
+ 'geom',
+ name=hand.name + '1_collision',
+ size=[0.01, 0.015],
+ pos=[0.032, 0, -0.02],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+ hand.add(
+ 'geom',
+ name=hand.name + '2_collision',
+ size=[0.008, 0.01],
+ pos=[0.042, 0, -0.022],
+ euler=(90, 0, 0),
+ dclass='foot_primitive',
+ )
+
+ palm = hand.add(
+ 'site',
+ name='palm' + side,
+ size=(0.028, 0.03, 0.007),
+ pos=(0.02, 0, -0.024),
+ type='box',
+ dclass='sensor',
+ )
+ palm_sites.append(palm)
+
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ for side in ['_L', '_R']:
+ # scapula:
+ scap = scapula[side]
+ geom = scap.get_children('geom')[0]
+ bound_geom = physics.bind(geom)
+ scap.add(
+ 'geom',
+ name=geom.name + '_collision',
+ pos=bound_geom.pos,
+ size=bound_geom.size * 0.8,
+ quat=bound_geom.quat,
+ type='box',
+ dclass='collision_primitive',
+ )
+ # upper arm:
+ arm = upper_arm[side]
+ geom = arm.get_children('geom')[0]
+ bound_geom = physics.bind(geom)
+ arm.add(
+ 'geom',
+ name=geom.name + '_collision',
+ pos=bound_geom.pos,
+ size=[0.02, 0.08],
+ quat=bound_geom.quat,
+ dclass='collision_primitive',
+ )
+
+ all_geoms = model.find_all('geom')
+ for geom in all_geoms:
+ if 'Ulna' in geom.name:
+ bound_geom = physics.bind(geom)
+ geom.parent.add(
+ 'geom',
+ name=geom.name + '_collision',
+ pos=bound_geom.pos,
+ size=[0.015, 0.06],
+ quat=bound_geom.quat,
+ dclass='collision_primitive',
+ )
+ if 'Radius' in geom.name:
+ bound_geom = physics.bind(geom)
+ pos = bound_geom.pos + np.array((-0.005, 0.0, -0.01))
+ geom.parent.add(
+ 'geom',
+ name=geom.name + '_collision',
+ pos=pos,
+ size=[0.017, 0.09],
+ quat=bound_geom.quat,
+ dclass='collision_primitive',
+ )
+
+ return palm_sites
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_neck.py b/dm_control/locomotion/walkers/assets/dog_v2/build_neck.py
new file mode 100644
index 00000000..8435796a
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_neck.py
@@ -0,0 +1,328 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make neck for the dog model."""
+
+import collections
+
+from dm_control import mjcf
+import numpy as np
+
+
+def create_neck(
+ model,
+ bone_position,
+ cervical_dofs_per_vertebra,
+ bones,
+ side_sign,
+ bone_size,
+ parent,
+):
+ """Add neck and head in the dog model.
+
+ Args:
+ model: model in which we want to add the neck.
+ bone_position: a dictionary of bone positions.
+ cervical_dofs_per_vertebra: a number that determines how many DoFs are
+ going to be used between each pair of cervical vertebrae.
+ bones: a list of strings with all the names of the bones.
+ side_sign: a dictionary with two axes representing the signs of
+ translations.
+ bone_size: dictionary containing the scale of the geometry.
+ parent: parent object on which we should start attaching new components.
+
+ Returns:
+ A list of cervical joints.
+ """
+ # Cervical Spine
+ def_cervical = model.default.find('default', 'cervical')
+ def_cervical_extend = model.default.find('default', 'cervical_extend')
+ def_cervical_bend = model.default.find('default', 'cervical_bend')
+ def_cervical_twist = model.default.find('default', 'cervical_twist')
+ cervical_defaults = {
+ 'extend': def_cervical_extend,
+ 'bend': def_cervical_bend,
+ 'twist': def_cervical_twist,
+ }
+
+ cervical_bones = ['C_' + str(i) for i in range(7, 0, -1)]
+ parent_pos = parent.pos
+ cervical_bodies = []
+ cervical_geoms = []
+ radius = 0.07
+ for i, bone in enumerate(cervical_bones):
+ bone_pos = bone_position[bone]
+ rel_pos = bone_pos - parent_pos
+ child = parent.add('body', name=bone, pos=rel_pos)
+ cervical_bodies.append(child)
+ dclass = 'bone' if i > 3 else 'light_bone'
+ geom = child.add('geom', name=bone, mesh=bone, pos=-bone_pos, dclass=dclass)
+ child.add(
+ 'geom',
+ name=bone + '_collision',
+ type='sphere',
+ size=[radius],
+ dclass='nonself_collision_primitive',
+ )
+ radius -= 0.006
+ cervical_geoms.append(geom)
+ parent = child
+ parent_pos = bone_pos
+
+ # Reload
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ # Cervical (neck) spine joints:
+ cervical_axis = collections.OrderedDict()
+ cervical_axis['extend'] = np.array((0.0, 1.0, 0.0))
+ cervical_axis['bend'] = np.array((0.0, 0.0, 1.0))
+ cervical_axis['twist'] = np.array((1.0, 0.0, 0))
+
+ num_dofs = 0
+ cervical_joints = []
+ cervical_joint_names = []
+ torso = model.find('body', 'torso')
+ parent = torso.find('geom', 'T_1')
+ for i, vertebra in enumerate(cervical_bodies):
+ while num_dofs < (i + 1) * cervical_dofs_per_vertebra:
+ dof = num_dofs % 3
+ dof_name = list(cervical_axis.keys())[dof]
+ cervical_joint_names.append(vertebra.name + '_' + dof_name)
+
+ rel_pos = physics.bind(vertebra).xpos - physics.bind(parent).xpos
+ twist_dir = rel_pos / np.linalg.norm(rel_pos)
+ bend_dir = np.cross(twist_dir, cervical_axis['extend'])
+ cervical_axis['bend'] = bend_dir
+ cervical_axis['twist'] = twist_dir
+ joint_frame = np.vstack((twist_dir, cervical_axis['extend'], bend_dir))
+ joint_pos = (
+ def_cervical.joint.pos
+ * physics.bind(vertebra.find('geom', vertebra.name)).size.mean()
+ )
+ joint = vertebra.add(
+ 'joint',
+ name=cervical_joint_names[-1],
+ dclass='cervical_' + dof_name,
+ axis=cervical_axis[dof_name],
+ pos=joint_pos.dot(joint_frame),
+ )
+ cervical_joints.append(joint)
+ num_dofs += 1
+ parent = vertebra
+
+ # Lumbar spine joints:
+ lumbar_axis = collections.OrderedDict()
+ lumbar_axis['extend'] = np.array((0.0, 1.0, 0.0))
+ lumbar_axis['bend'] = np.array((0.0, 0.0, 1.0))
+ lumbar_axis['twist'] = np.array((1.0, 0.0, 0))
+
+ # Scale joint defaults relative to 3 cervical_dofs_per_vertebra
+ for dof in lumbar_axis.keys():
+ axis_scale = 7.0 / [dof in joint for joint in cervical_joint_names].count(
+ True
+ )
+ cervical_defaults[dof].joint.range *= axis_scale
+
+ # Reload
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ # Skull
+ c_1 = cervical_bodies[-1]
+ upper_teeth = [m for m in bones if 'Top' in m]
+ skull_bones = upper_teeth + ['Skull', 'Ethmoid', 'Vomer', 'eye_L', 'eye_R']
+ skull = c_1.add(
+ 'body', name='skull', pos=bone_position['Skull'] - physics.bind(c_1).xpos
+ )
+ skull_geoms = []
+ for bone in skull_bones:
+ geom = skull.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Skull'],
+ dclass='light_bone',
+ )
+ if 'eye' in bone:
+ geom.rgba = [1, 1, 1, 1]
+ geom.dclass = 'visible_bone'
+ skull_geoms.append(geom)
+ if bone in upper_teeth:
+ geom.dclass = 'visible_bone'
+
+ for side in ['_L', '_R']:
+ pos = np.array((0.023, -0.027, 0.01)) * side_sign[side]
+ skull.add(
+ 'geom',
+ name='iris' + side,
+ type='ellipsoid',
+ dclass='visible_bone',
+ rgba=(0.45, 0.45, 0.225, 0.4),
+ size=(0.003, 0.007, 0.007),
+ pos=pos,
+ euler=[0, 0, -20 * (1.0 if side == '_R' else -1.0)],
+ )
+ pos = np.array((0.0215, -0.0275, 0.01)) * side_sign[side]
+ skull.add(
+ 'geom',
+ name='pupil' + side,
+ type='sphere',
+ dclass='visible_bone',
+ rgba=(0, 0, 0, 1),
+ size=(0.003, 0, 0),
+ pos=pos,
+ )
+
+ # collision geoms
+ skull.add(
+ 'geom',
+ name='skull0' + '_collision',
+ type='ellipsoid',
+ dclass='collision_primitive',
+ size=(0.06, 0.06, 0.04),
+ pos=(-0.02, 0, 0.01),
+ euler=[0, 10, 0],
+ )
+ skull.add(
+ 'geom',
+ name='skull1' + '_collision',
+ type='capsule',
+ dclass='collision_primitive',
+ size=(0.015, 0.04, 0.015),
+ pos=(0.06, 0, -0.01),
+ euler=[0, 110, 0],
+ )
+ skull.add(
+ 'geom',
+ name='skull2' + '_collision',
+ type='box',
+ dclass='collision_primitive',
+ size=(0.03, 0.028, 0.008),
+ pos=(0.02, 0, -0.03),
+ )
+ skull.add(
+ 'geom',
+ name='skull3' + '_collision',
+ type='box',
+ dclass='collision_primitive',
+ size=(0.02, 0.018, 0.006),
+ pos=(0.07, 0, -0.03),
+ )
+ skull.add(
+ 'geom',
+ name='skull4' + '_collision',
+ type='box',
+ dclass='collision_primitive',
+ size=(0.005, 0.015, 0.004),
+ pos=(0.095, 0, -0.03),
+ )
+
+ skull.add(
+ 'joint',
+ name='atlas',
+ dclass='atlas',
+ pos=np.array((-0.5, 0, 0)) * bone_size['Skull'],
+ )
+
+ skull.add(
+ 'site', name='head', size=(0.01, 0.01, 0.01), type='box', dclass='sensor'
+ )
+ skull.add(
+ 'site',
+ name='upper_bite',
+ size=(0.005,),
+ dclass='sensor',
+ pos=(0.065, 0, -0.07),
+ )
+ # Jaw
+ lower_teeth = [m for m in bones if 'Bottom' in m]
+ jaw_bones = lower_teeth + ['Mandible']
+ jaw = skull.add(
+ 'body', name='jaw', pos=bone_position['Mandible'] - bone_position['Skull']
+ )
+ jaw_geoms = []
+ for bone in jaw_bones:
+ geom = jaw.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Mandible'],
+ dclass='light_bone',
+ )
+ jaw_geoms.append(geom)
+ if bone in lower_teeth:
+ geom.dclass = 'visible_bone'
+ # Jaw collision geoms:
+ jaw_col_pos = [
+ (-0.03, 0, 0.01),
+ (0, 0, -0.012),
+ (0.03, 0, -0.028),
+ (0.052, 0, -0.035),
+ ]
+ jaw_col_size = [
+ (0.03, 0.028, 0.008),
+ (0.02, 0.022, 0.005),
+ (0.02, 0.018, 0.005),
+ (0.015, 0.013, 0.003),
+ ]
+ jaw_col_angle = [55, 30, 25, 15]
+ for i in range(4):
+ jaw.add(
+ 'geom',
+ name='jaw' + str(i) + '_collision',
+ type='box',
+ dclass='collision_primitive',
+ size=jaw_col_size[i],
+ pos=jaw_col_pos[i],
+ euler=[0, jaw_col_angle[i], 0],
+ )
+
+ jaw.add(
+ 'joint',
+ name='mandible',
+ dclass='mandible',
+ axis=[0, 1, 0],
+ pos=np.array((-0.043, 0, 0.05)),
+ )
+ jaw.add(
+ 'site',
+ name='lower_bite',
+ size=(0.005,),
+ dclass='sensor',
+ pos=(0.063, 0, 0.005),
+ )
+
+ print('Make collision ellipsoids for teeth.')
+ visible_bones = upper_teeth + lower_teeth
+ for bone in visible_bones:
+ bone_geom = torso.find('geom', bone)
+ bone_geom.type = 'ellipsoid'
+ physics = mjcf.Physics.from_mjcf_model(model)
+ for bone in visible_bones:
+ bone_geom = torso.find('geom', bone)
+ pos = physics.bind(bone_geom).pos
+ quat = physics.bind(bone_geom).quat
+ size = physics.bind(bone_geom).size
+ bone_geom.parent.add(
+ 'geom',
+ name=bone + '_collision',
+ dclass='tooth_primitive',
+ pos=pos,
+ size=size * 1.2,
+ quat=quat,
+ type='ellipsoid',
+ )
+ bone_geom.type = None
+
+ return cervical_joints
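
Review note on `build_neck.py`: the joint loop hands out degrees of freedom by cycling through `extend`, `bend`, `twist` until each vertebra has reached its cumulative quota, and the joint ranges are then scaled by `7.0 / count` per axis. The following is a minimal standalone sketch of that bookkeeping only, with no MuJoCo model involved. The helper names `assign_joint_names` and `range_scale` are hypothetical and do not exist in the patch.

```python
# Hypothetical helpers that mirror the DoF bookkeeping in create_neck();
# they are an illustration, not part of dm_control.
AXIS_NAMES = ('extend', 'bend', 'twist')


def assign_joint_names(vertebrae, dofs_per_vertebra, axis_names=AXIS_NAMES):
  """Cycles through the axis names until each vertebra has its quota."""
  names = []
  num_dofs = 0
  for i, vertebra in enumerate(vertebrae):
    # Vertebra i may add joints until the running total reaches
    # (i + 1) * dofs_per_vertebra, so fractional values are allowed.
    while num_dofs < (i + 1) * dofs_per_vertebra:
      dof_name = axis_names[num_dofs % len(axis_names)]
      names.append(vertebra + '_' + dof_name)
      num_dofs += 1
  return names


def range_scale(names, axis_names=AXIS_NAMES, reference=7.0):
  """Returns reference / (number of joints using each axis)."""
  return {
      axis: reference / sum(axis in name for name in names)
      for axis in axis_names
  }


cervical = ['C_' + str(i) for i in range(7, 0, -1)]
one_dof = assign_joint_names(cervical, 1)
three_dof = assign_joint_names(cervical, 3)
```

With three DoFs per vertebra every axis is used seven times, so each scale factor is 1.0. With fewer DoFs the remaining joints get a proportionally wider range. As in the patch, an axis that is never assigned would cause a division by zero.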
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_tail.py b/dm_control/locomotion/walkers/assets/dog_v2/build_tail.py
new file mode 100644
index 00000000..40740bda
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_tail.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make tail for the dog model."""
+
+import collections
+
+from dm_control import mjcf
+import numpy as np
+
+
+def create_tail(
+ caudal_dofs_per_vertebra, bone_size, model, bone_position, parent
+):
+ """Add tail in the dog model.
+
+ Args:
+ caudal_dofs_per_vertebra: a number that determines how many DoFs are
+ going to be used between each pair of caudal vertebrae.
+ bone_size: dictionary containing the scale of the geometry.
+ model: model in which we want to add the tail.
+ bone_position: a dictionary of bone positions.
+ parent: parent object on which we should start attaching new components.
+
+ Returns:
+ A list of caudal joints.
+ """
+ # Caudal spine (tail) bodies:
+ caudal_bones = ['Ca_' + str(i + 1) for i in range(21)]
+ parent_pos = bone_position['Pelvis']
+ caudal_bodies = []
+ caudal_geoms = []
+ for bone in caudal_bones:
+ bone_pos = bone_position[bone]
+ rel_pos = bone_pos - parent_pos
+ xyaxes = np.hstack((-rel_pos, (0, 1, 0)))
+ xyaxes[1] = 0
+ child = parent.add('body', name=bone, pos=rel_pos)
+ caudal_bodies.append(child)
+ geom = child.add('geom', name=bone, mesh=bone, pos=-bone_pos, dclass='bone')
+ caudal_geoms.append(geom)
+ parent = child
+ parent_pos = bone_pos
+
+ # Reload
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ # Caudal spine joints:
+ caudal_axis = collections.OrderedDict()
+ caudal_axis['extend'] = np.array((0.0, 1.0, 0.0))
+
+ scale = np.asarray([bone_size[bone] for bone in caudal_bones]).mean()
+ joint_pos = np.array((0.3, 0, 0.26)) * scale
+ num_dofs = 0
+ caudal_joints = []
+ caudal_joint_names = []
+ parent = model.find('geom', 'Sacrum')
+ for i, vertebra in enumerate(caudal_bodies):
+ while num_dofs < (i + 1) * caudal_dofs_per_vertebra:
+ dof = num_dofs % 2
+ dof_name = list(caudal_axis.keys())[dof]
+ caudal_joint_names.append(vertebra.name + '_' + dof_name)
+ rel_pos = physics.bind(parent).xpos - physics.bind(vertebra).xpos
+ twist_dir = rel_pos / np.linalg.norm(rel_pos)
+ bend_dir = np.cross(caudal_axis['extend'], twist_dir)
+ caudal_axis['bend'] = bend_dir
+ joint_pos = twist_dir * physics.bind(caudal_geoms[i]).size[2]
+
+ joint = vertebra.add(
+ 'joint',
+ name=caudal_joint_names[-1],
+ dclass='caudal_' + dof_name,
+ axis=caudal_axis[dof_name],
+ pos=joint_pos,
+ )
+ caudal_joints.append(joint)
+ num_dofs += 1
+ parent = vertebra
+
+ parent.add('site', name='tail_tip', dclass='sensor', size=(0.005,))
+
+ physics = mjcf.Physics.from_mjcf_model(model)
+ all_geoms = model.find_all('geom')
+
+ for geom in all_geoms:
+ if 'Ca_' in geom.name:
+ sc = (float(geom.name[3:]) + 1) / 4
+ scale = np.array((1.2, sc, 1))
+ bound_geom = physics.bind(geom)
+ geom.parent.add(
+ 'geom',
+ name=geom.name + '_collision',
+ pos=bound_geom.pos,
+ size=bound_geom.size * scale,
+ quat=bound_geom.quat,
+ dclass='collision_primitive',
+ )
+ return caudal_joints
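
Review note on `build_tail.py`: each caudal joint derives its axes from the direction between a vertebra and its parent. The twist direction is the normalized offset and the bend direction is the cross product of the fixed `extend` axis with it. The sketch below reproduces that arithmetic with plain NumPy arrays standing in for `physics.bind(...).xpos`. The function name `tail_joint_axes` is hypothetical.

```python
import numpy as np


def tail_joint_axes(parent_xpos, vertebra_xpos, extend_axis=(0.0, 1.0, 0.0)):
  """Returns (twist_dir, bend_dir) following the construction in create_tail().

  The inputs stand in for world positions read from a compiled model.
  """
  rel_pos = np.asarray(parent_xpos, dtype=float) - np.asarray(
      vertebra_xpos, dtype=float
  )
  # Unit vector pointing from the vertebra back towards its parent.
  twist_dir = rel_pos / np.linalg.norm(rel_pos)
  # Perpendicular to both the extend axis and the twist direction.
  bend_dir = np.cross(np.asarray(extend_axis, dtype=float), twist_dir)
  return twist_dir, bend_dir


# A vertebra placed directly behind its parent along the negative x axis.
twist, bend = tail_joint_axes((0.0, 0.0, 0.0), (-2.0, 0.0, 0.0))
```

Note that `bend_dir` has unit length only when the twist direction is perpendicular to the `extend` axis. The sketch, like the patch, does not renormalize it.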
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/build_torso.py b/dm_control/locomotion/walkers/assets/dog_v2/build_torso.py
new file mode 100644
index 00000000..dcaf3bba
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/build_torso.py
@@ -0,0 +1,181 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make torso for the dog model."""
+
+import collections
+
+from dm_control import mjcf
+import numpy as np
+
+
+def create_torso(
+ model, bones, bone_position, lumbar_dofs_per_vertebra, side_sign, parent
+):
+ """Add torso in the dog model.
+
+ Args:
+ model: model in which we want to add the torso.
+ bones: a list of strings with all the names of the bones.
+ bone_position: a dictionary of bone positions.
+ lumbar_dofs_per_vertebra: a number that determines how many DoFs are
+ going to be used between each pair of lumbar vertebrae.
+ side_sign: a dictionary with two axes representing the signs of
+ translations.
+ parent: parent object on which we should start attaching new components.
+
+ Returns:
+ The tuple `(pelvic_bones, lumbar_joints)`.
+ """
+ # Lumbar Spine
+ def_lumbar_extend = model.default.find('default', 'lumbar_extend')
+ def_lumbar_bend = model.default.find('default', 'lumbar_bend')
+ def_lumbar_twist = model.default.find('default', 'lumbar_twist')
+ lumbar_defaults = {
+ 'extend': def_lumbar_extend,
+ 'bend': def_lumbar_bend,
+ 'twist': def_lumbar_twist,
+ }
+
+ thoracic_spine = [m for m in bones if 'T_' in m]
+ ribs = [m for m in bones if 'Rib' in m and 'cage' not in m]
+ sternum = [m for m in bones if 'Sternum' in m]
+ torso_bones = thoracic_spine + ribs + sternum # + ['Xiphoid_cartilage']
+ torso = parent.add('body', name='torso')
+ torso.add('freejoint', name='root')
+ torso.add('site', name='root', size=(0.01,), rgba=[0, 1, 0, 1])
+ torso.add('light', name='light', mode='trackcom', pos=[0, 0, 3])
+ torso.add(
+ 'camera',
+ name='y-axis',
+ mode='trackcom',
+ pos=[0, -1.5, 0.8],
+ xyaxes=[1, 0, 0, 0, 0.6, 1],
+ )
+ torso.add(
+ 'camera',
+ name='x-axis',
+ mode='trackcom',
+ pos=[2, 0, 0.5],
+ xyaxes=[0, 1, 0, -0.3, 0, 1],
+ )
+ torso_geoms = []
+ for bone in torso_bones:
+ torso_geoms.append(
+ torso.add('geom', name=bone, mesh=bone, dclass='light_bone')
+ )
+
+ # Reload, get CoM position, set pos
+ physics = mjcf.Physics.from_mjcf_model(model)
+ torso_pos = np.array(physics.bind(model.find('body', 'torso')).xipos)
+ torso.pos = torso_pos
+ for geom in torso_geoms:
+ geom.pos = -torso_pos
+
+ # Collision primitive for torso
+ torso.add(
+ 'geom',
+ name='collision_torso',
+ dclass='nonself_collision_primitive',
+ type='ellipsoid',
+ pos=[0, 0, 0],
+ size=[0.2, 0.09, 0.11],
+ euler=[0, 10, 0],
+ density=200,
+ )
+
+ # Lumbar spine bodies:
+ lumbar_bones = ['L_1', 'L_2', 'L_3', 'L_4', 'L_5', 'L_6', 'L_7']
+ parent = torso
+ parent_pos = torso_pos
+ lumbar_bodies = []
+ lumbar_geoms = []
+ for i, bone in enumerate(lumbar_bones):
+ bone_pos = bone_position[bone]
+ child = parent.add('body', name=bone, pos=bone_pos - parent_pos)
+ lumbar_bodies.append(child)
+ geom = child.add('geom', name=bone, mesh=bone, pos=-bone_pos, dclass='bone')
+ child.add(
+ 'geom',
+ name=bone + '_collision',
+ type='sphere',
+ size=[0.05],
+ pos=[0, 0, -0.02],
+ dclass='nonself_collision_primitive',
+ )
+ lumbar_geoms.append(geom)
+ parent = child
+ parent_pos = bone_pos
+ l_7 = parent
+
+ # Lumbar spine joints:
+ lumbar_axis = collections.OrderedDict()
+ lumbar_axis['extend'] = np.array((0.0, 1.0, 0.0))
+ lumbar_axis['bend'] = np.array((0.0, 0.0, 1.0))
+ lumbar_axis['twist'] = np.array((1.0, 0.0, 0))
+
+ num_dofs = 0
+ lumbar_joints = []
+ lumbar_joint_names = []
+ for i, vertebra in enumerate(lumbar_bodies):
+ while num_dofs < (i + 1) * lumbar_dofs_per_vertebra:
+ dof = num_dofs % 3
+ dof_name = list(lumbar_axis.keys())[dof]
+ dof_axis = lumbar_axis[dof_name]
+ lumbar_joint_names.append(vertebra.name + '_' + dof_name)
+ joint = vertebra.add(
+ 'joint',
+ name=lumbar_joint_names[-1],
+ dclass='lumbar_' + dof_name,
+ axis=dof_axis,
+ )
+ lumbar_joints.append(joint)
+ num_dofs += 1
+
+ # Scale joint defaults relative to 3 lumbar_dofs_per_vertebra
+ for dof in lumbar_axis.keys():
+ axis_scale = 7.0 / [dof in joint for joint in lumbar_joint_names].count(
+ True
+ )
+ lumbar_defaults[dof].joint.range *= axis_scale
+
+ # Pelvis:
+ pelvis = l_7.add(
+ 'body', name='pelvis', pos=bone_position['Pelvis'] - bone_position['L_7']
+ )
+ pelvic_bones = ['Sacrum', 'Pelvis']
+ pelvic_geoms = []
+ for bone in pelvic_bones:
+ geom = pelvis.add(
+ 'geom',
+ name=bone,
+ mesh=bone,
+ pos=-bone_position['Pelvis'],
+ dclass='bone',
+ )
+ pelvic_geoms.append(geom)
+ # Collision primitives for pelvis
+ for side in ['_L', '_R']:
+ pos = np.array((0.01, -0.02, -0.01)) * side_sign[side]
+ pelvis.add(
+ 'geom',
+ name='collision_pelvis' + side,
+ pos=pos,
+ size=[0.05, 0.05, 0],
+ euler=[0, 70, 0],
+ dclass='nonself_collision_primitive',
+ )
+
+ return pelvic_bones, lumbar_joints
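
Review note on `build_torso.py`: every vertebra body is positioned relative to its parent (`bone_pos - parent_pos`) and its mesh geom is then shifted by `-bone_pos`. If the bodies carry no rotation, these offsets cancel, so each mesh stays at the world coordinates it was authored in. The following sketch checks that invariant with plain NumPy. The helper names are hypothetical, and the zero-rotation assumption is mine rather than something the patch states.

```python
import numpy as np


def chain_offsets(root_pos, bone_positions):
  """Returns (body_pos, geom_pos) pairs for a chain of nested bodies."""
  parent_pos = np.asarray(root_pos, dtype=float)
  offsets = []
  for bone_pos in bone_positions:
    bone_pos = np.asarray(bone_pos, dtype=float)
    body_pos = bone_pos - parent_pos  # body frame relative to its parent
    geom_pos = -bone_pos  # mesh offset inside the body frame
    offsets.append((body_pos, geom_pos))
    parent_pos = bone_pos
  return offsets


def geom_world_positions(root_pos, offsets):
  """Accumulates the offsets, assuming that no body is rotated."""
  world = np.asarray(root_pos, dtype=float)
  result = []
  for body_pos, geom_pos in offsets:
    world = world + body_pos
    result.append(world + geom_pos)
  return np.asarray(result)


# Illustrative numbers only, not taken from the dog model.
root = (0.1, 0.0, 0.3)
bones = [(-0.1, 0.0, 0.32), (-0.2, 0.0, 0.33), (-0.3, 0.0, 0.31)]
world_geoms = geom_world_positions(root, chain_offsets(root, bones))
```

Every row of `world_geoms` is the origin up to floating-point rounding, which is why the mesh vertices need no further adjustment after being reparented.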
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/create_skin.py b/dm_control/locomotion/walkers/assets/dog_v2/create_skin.py
new file mode 100644
index 00000000..3ef74f0e
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/create_skin.py
@@ -0,0 +1,210 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Make skin for the dog model."""
+
+import struct
+
+from dm_control import mjcf
+from dm_control.mujoco.wrapper.mjbindings import enums
+import numpy as np
+from scipy import spatial
+
+
+def create(model, skin_msh):
+ """Create and add skin in the dog model.
+
+ Args:
+ model: model in which we want to add the skin.
+ skin_msh: a binary mesh format of the skin.
+ """
+ print('Making Skin.')
+ # Add skin mesh:
+ skinmesh = model.worldbody.add(
+ 'geom',
+ name='skinmesh',
+ mesh='skin_msh',
+ type='mesh',
+ contype=0,
+ conaffinity=0,
+ rgba=[1, 0.5, 0.5, 0.5],
+ group=1,
+ euler=(0, 0, 90),
+ )
+ physics = mjcf.Physics.from_mjcf_model(model)
+
+ # Get skinmesh vertices in global coordinates
+ vertadr = physics.named.model.mesh_vertadr['skin_msh']
+ vertnum = physics.named.model.mesh_vertnum['skin_msh']
+ skin_vertices = physics.model.mesh_vert[vertadr : vertadr + vertnum, :]
+ skin_vertices = skin_vertices.dot(
+ physics.named.data.geom_xmat['skinmesh'].reshape(3, 3).T
+ )
+ skin_vertices += physics.named.data.geom_xpos['skinmesh']
+ skin_normals = physics.model.mesh_normal[vertadr : vertadr + vertnum, :]
+ skin_normals = skin_normals.dot(
+ physics.named.data.geom_xmat['skinmesh'].reshape(3, 3).T
+ )
+ skin_normals += physics.named.data.geom_xpos['skinmesh']
+
+ # Get skinmesh faces
+ faceadr = physics.named.model.mesh_faceadr['skin_msh']
+ facenum = physics.named.model.mesh_facenum['skin_msh']
+ skin_faces = physics.model.mesh_face[faceadr : faceadr + facenum, :]
+
+ # Make skin
+ skin = model.asset.add(
+ 'skin', name='skin', vertex=skin_vertices.ravel(), face=skin_faces.ravel()
+ )
+
+ # Functions for capsule vertices
+ numslices = 10
+ numstacks = 10
+ numquads = 8
+
+ def hemisphere(radius):
+ positions = []
+ for az in np.linspace(0, 2 * np.pi, numslices, False):
+ for el in np.linspace(0, np.pi, numstacks, False):
+ pos = np.asarray(
+ [np.cos(el) * np.cos(az), np.cos(el) * np.sin(az), np.sin(el)]
+ )
+ positions.append(pos)
+ return radius * np.asarray(positions)
+
+ def cylinder(radius, height):
+ positions = []
+ for az in np.linspace(0, 2 * np.pi, numslices, False):
+ for el in np.linspace(-1, 1, numstacks):
+ pos = np.asarray(
+ [radius * np.cos(az), radius * np.sin(az), height * el]
+ )
+ positions.append(pos)
+ return np.asarray(positions)
+
+ def capsule(radius, height):
+ hp = hemisphere(radius)
+ cy = cylinder(radius, height)
+ offset = np.array((0, 0, height))
+ return np.unique(np.vstack((cy, hp + offset, -hp - offset)), axis=0)
+
+ def ellipsoid(size):
+ hp = hemisphere(1)
+ sphere = np.unique(np.vstack((hp, -hp)), axis=0)
+ return sphere * size
+
+ def box(sx, sy, sz):
+ positions = []
+ for x in np.linspace(-sx, sx, numquads + 1):
+ for y in np.linspace(-sy, sy, numquads + 1):
+ for z in np.linspace(-sz, sz, numquads + 1):
+ if abs(x) == sx or abs(y) == sy or abs(z) == sz:
+ pos = np.asarray([x, y, z])
+ positions.append(pos)
+ return np.unique(np.asarray(positions), axis=0)
+
+ # Find smallest distance between
+ # each skin vertex and vertices of all meshes in body i
+ distance = np.zeros((skin_vertices.shape[0], physics.model.nbody))
+ for i in range(1, physics.model.nbody):
+ geom_id = np.argwhere(physics.model.geom_bodyid == i).ravel()
+ mesh_id = physics.model.geom_dataid[geom_id]
+ body_verts = []
+ for k, gid in enumerate(geom_id):
+ skip = False
+ if physics.model.geom_type[gid] == enums.mjtGeom.mjGEOM_MESH:
+ vertadr = physics.model.mesh_vertadr[mesh_id[k]]
+ vertnum = physics.model.mesh_vertnum[mesh_id[k]]
+ vertices = physics.model.mesh_vert[vertadr : vertadr + vertnum, :]
+ elif physics.model.geom_type[gid] == enums.mjtGeom.mjGEOM_CAPSULE:
+ radius = physics.model.geom_size[gid, 0]
+ height = physics.model.geom_size[gid, 1]
+ vertices = capsule(radius, height)
+ elif physics.model.geom_type[gid] == enums.mjtGeom.mjGEOM_ELLIPSOID:
+ vertices = ellipsoid(physics.model.geom_size[gid])
+ elif physics.model.geom_type[gid] == enums.mjtGeom.mjGEOM_BOX:
+ vertices = box(*physics.model.geom_size[gid])
+ else:
+ skip = True
+ if not skip:
+ vertices = vertices.dot(physics.data.geom_xmat[gid].reshape(3, 3).T)
+ vertices += physics.data.geom_xpos[gid]
+ body_verts.append(vertices)
+
+ body_verts = np.vstack((body_verts))
+ # hull = spatial.ConvexHull(body_verts)
+ tree = spatial.cKDTree(body_verts)
+ distance[:, i], _ = tree.query(skin_vertices)
+
+ # non-KDTree implementation of the above 2 lines:
+ # distance[:, i] = np.amin(
+ # spatial.distance.cdist(skin_vertices, body_verts, 'euclidean'),
+ # axis=1)
+
+ # Calculate bone weights from distances
+ sigma = 0.015
+ weights = np.exp(-distance[:, 1:] / sigma)
+ threshold = 0.01
+ weights /= np.atleast_2d(np.sum(weights, axis=1)).T
+ weights[weights < threshold] = 0
+ weights /= np.atleast_2d(np.sum(weights, axis=1)).T
+
+ for i in range(1, physics.model.nbody):
+ vertweight = weights[weights[:, i - 1] >= threshold, i - 1]
+ vertid = np.argwhere(weights[:, i - 1] >= threshold).ravel()
+ if vertid.size:
+ skin.add(
+ 'bone',
+ body=physics.model.id2name(i, 'body'),
+ bindquat=[1, 0, 0, 0],
+ bindpos=physics.data.xpos[i, :],
+ vertid=vertid,
+ vertweight=vertweight,
+ )
+
+ # Remove skinmesh
+ skinmesh.remove()
+
+ # Convert skin into *.skn file according to
+ # https://mujoco.readthedocs.io/en/latest/XMLreference.html#asset-skin
+ f = open('dog_skin.skn', 'w+b')
+ nvert = skin.vertex.size // 3
+ f.write(
+ struct.pack(
+ '4i', nvert, nvert, skin.face.size // 3, physics.model.nbody - 1
+ )
+ )
+ f.write(struct.pack(str(skin.vertex.size) + 'f', *skin.vertex))
+ assert physics.model.mesh_texcoord.shape[0] == physics.bind(skin_msh).vertnum
+ f.write(
+ struct.pack(str(2 * nvert) + 'f', *physics.model.mesh_texcoord.flatten())
+ )
+ f.write(struct.pack(str(skin.face.size) + 'i', *skin.face))
+ for bone in skin.bone:
+ name_length = len(bone.body)
+ assert name_length <= 40
+ f.write(
+ struct.pack(str(name_length) + 'c', *[s.encode() for s in bone.body])
+ )
+ f.write((40 - name_length) * b'\x00')
+ f.write(struct.pack('3f', *bone.bindpos))
+ f.write(struct.pack('4f', *bone.bindquat))
+ f.write(struct.pack('i', bone.vertid.size))
+ f.write(struct.pack(str(bone.vertid.size) + 'i', *bone.vertid))
+ f.write(struct.pack(str(bone.vertid.size) + 'f', *bone.vertweight))
+ f.close()
+
+ # Remove XML-based skin, add binary skin.
+ skin.remove()
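
Review note on `create_skin.py`: the bone weights are derived from vertex-to-body distances by an exponential falloff, followed by normalization, thresholding, and a second normalization. The sketch below isolates that step on a small hand-written distance matrix. The function name `skin_weights` is hypothetical, and the default `sigma` and `threshold` are copied from the patch.

```python
import numpy as np


def skin_weights(distance, sigma=0.015, threshold=0.01):
  """Turns a (num_vertices, num_bodies) distance matrix into bone weights."""
  weights = np.exp(-np.asarray(distance, dtype=float) / sigma)
  # First normalization: each vertex's weights sum to one.
  weights /= weights.sum(axis=1, keepdims=True)
  # Drop negligible influences, then renormalize what is left.
  weights[weights < threshold] = 0
  weights /= weights.sum(axis=1, keepdims=True)
  return weights


# Two vertices and three bodies, distances in metres (illustrative values).
example = skin_weights([[0.0, 0.015, 0.2], [0.05, 0.05, 0.05]])
```

In the first row the distant third body falls below the threshold and receives exactly zero weight. In the second row all three bodies are equidistant, so each receives one third.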
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/dog.xml b/dm_control/locomotion/walkers/assets/dog_v2/dog.xml
new file mode 100644
index 00000000..daf7f537
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/dog.xml
@@ -0,0 +1,992 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/dog_base.xml b/dm_control/locomotion/walkers/assets/dog_v2/dog_base.xml
new file mode 100644
index 00000000..6a6210fd
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/dog_base.xml
@@ -0,0 +1,170 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/dog_skin.skn b/dm_control/locomotion/walkers/assets/dog_v2/dog_skin.skn
new file mode 100644
index 00000000..e11e0264
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/dog_v2/dog_skin.skn differ
diff --git a/dm_control/locomotion/walkers/assets/dog_v2/scene.xml b/dm_control/locomotion/walkers/assets/dog_v2/scene.xml
new file mode 100644
index 00000000..d70d343a
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/dog_v2/scene.xml
@@ -0,0 +1,50 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_body.msh
new file mode 100644
index 00000000..6fd0e732
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_lower.msh
new file mode 100644
index 00000000..fe06a608
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_1_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_body.msh
new file mode 100644
index 00000000..1d9e39f0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_lower.msh
new file mode 100644
index 00000000..50c6e35f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_2_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_body.msh
new file mode 100644
index 00000000..bb74f96f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_lower.msh
new file mode 100644
index 00000000..9c9b4fc6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_3_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_body.msh
new file mode 100644
index 00000000..18f5df14
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_lower.msh
new file mode 100644
index 00000000..d5536126
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_4_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_body.msh
new file mode 100644
index 00000000..df98adb7
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_lower.msh
new file mode 100644
index 00000000..c23560ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_5_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_body.msh
new file mode 100644
index 00000000..1112a798
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_lower.msh
new file mode 100644
index 00000000..eb6a4a1f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_6_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_body.msh
new file mode 100644
index 00000000..d6aba9c8
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_lower.msh
new file mode 100644
index 00000000..a1a40d27
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_7_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_8_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_8_body.msh
new file mode 100644
index 00000000..e5c53f11
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/abdomen_8_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_black.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_black.msh
new file mode 100644
index 00000000..353f67a6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_body.msh
new file mode 100644
index 00000000..c274dc87
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_black.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_black.msh
new file mode 100644
index 00000000..9a742a36
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_body.msh
new file mode 100644
index 00000000..a0493295
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/antenna_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/blender_model/drosophila_v2.blend b/dm_control/locomotion/walkers/assets/fruitfly_v2/blender_model/drosophila_v2.blend
new file mode 100755
index 00000000..bb622aa8
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/blender_model/drosophila_v2.blend differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_left_body.msh
new file mode 100644
index 00000000..9df055b0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_right_body.msh
new file mode 100644
index 00000000..990dd6c6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_left_body.msh
new file mode 100644
index 00000000..ef693f0e
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_right_body.msh
new file mode 100644
index 00000000..98be03ac
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_left_body.msh
new file mode 100644
index 00000000..256dc0cc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_right_body.msh
new file mode 100644
index 00000000..3e7814a1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/coxa_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_left_body.msh
new file mode 100644
index 00000000..59d3fc7b
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_right_body.msh
new file mode 100644
index 00000000..0d385990
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_left_body.msh
new file mode 100644
index 00000000..cd49ae92
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_right_body.msh
new file mode 100644
index 00000000..2804b089
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_left_body.msh
new file mode 100644
index 00000000..a3ce197f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_right_body.msh
new file mode 100644
index 00000000..93819cee
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/femur_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/floor.xml b/dm_control/locomotion/walkers/assets/fruitfly_v2/floor.xml
new file mode 100644
index 00000000..b2bf705c
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/fruitfly_v2/floor.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/fruitfly.xml b/dm_control/locomotion/walkers/assets/fruitfly_v2/fruitfly.xml
new file mode 100644
index 00000000..a68e3e39
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/fruitfly_v2/fruitfly.xml
@@ -0,0 +1,918 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_left_body.msh
new file mode 100644
index 00000000..bc4d52d1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_right_body.msh
new file mode 100644
index 00000000..545cbab0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/haltere_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_black.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_black.msh
new file mode 100644
index 00000000..ce3c56f3
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_body.msh
new file mode 100644
index 00000000..f9d83b99
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/haustellum_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/head_black.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_black.msh
new file mode 100644
index 00000000..ffb2f26b
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/head_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_body.msh
new file mode 100644
index 00000000..4236c6db
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/head_ocelli.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_ocelli.msh
new file mode 100644
index 00000000..4a4c281a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_ocelli.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/head_red.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_red.msh
new file mode 100644
index 00000000..53d30f86
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/head_red.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_left_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_left_lower.msh
new file mode 100644
index 00000000..f00f2256
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_left_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_right_lower.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_right_lower.msh
new file mode 100644
index 00000000..b183450c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/labrum_right_lower.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_body.msh
new file mode 100644
index 00000000..f0bd54e3
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_bristle-brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_bristle-brown.msh
new file mode 100644
index 00000000..f5dca535
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/rostrum_bristle-brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_left_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_left_brown.msh
new file mode 100644
index 00000000..104fa46f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_right_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_right_brown.msh
new file mode 100644
index 00000000..51a8ce90
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T1_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_left_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_left_brown.msh
new file mode 100644
index 00000000..9fd8426a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_right_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_right_brown.msh
new file mode 100644
index 00000000..bd8127c6
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T2_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_left_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_left_brown.msh
new file mode 100644
index 00000000..4c224a0d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_right_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_right_brown.msh
new file mode 100644
index 00000000..f1a12541
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsal_claw_T3_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_left_body.msh
new file mode 100644
index 00000000..0d2926cc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_right_body.msh
new file mode 100644
index 00000000..b2238e9a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_left_body.msh
new file mode 100644
index 00000000..f8c0375f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_right_body.msh
new file mode 100644
index 00000000..81eba2ea
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_left_body.msh
new file mode 100644
index 00000000..ae1e169d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_right_body.msh
new file mode 100644
index 00000000..475817de
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_left_body.msh
new file mode 100644
index 00000000..116844cb
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_right_body.msh
new file mode 100644
index 00000000..97acf78d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T1_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_left_body.msh
new file mode 100644
index 00000000..46ed4fb4
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_right_body.msh
new file mode 100644
index 00000000..b49ee4c2
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_left_body.msh
new file mode 100644
index 00000000..7b81306c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_right_body.msh
new file mode 100644
index 00000000..d1a7d6ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_left_body.msh
new file mode 100644
index 00000000..fa26f25a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_right_body.msh
new file mode 100644
index 00000000..e66cb2de
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_left_body.msh
new file mode 100644
index 00000000..72fb71ef
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_right_body.msh
new file mode 100644
index 00000000..563fb5e0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T2_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_left_body.msh
new file mode 100644
index 00000000..5cf0ad6f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_right_body.msh
new file mode 100644
index 00000000..45b35ea0
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_left_body.msh
new file mode 100644
index 00000000..3a6e0592
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_right_body.msh
new file mode 100644
index 00000000..7ee5024e
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_left_body.msh
new file mode 100644
index 00000000..9fca2447
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_right_body.msh
new file mode 100644
index 00000000..416e429d
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_left_body.msh
new file mode 100644
index 00000000..ae032249
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_right_body.msh
new file mode 100644
index 00000000..fb4a36d5
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tarsus_T3_4_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_black.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_black.msh
new file mode 100644
index 00000000..652ba9fc
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_black.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_body.msh
new file mode 100644
index 00000000..31965586
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/thorax_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_left_body.msh
new file mode 100644
index 00000000..0c6104c1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_right_body.msh
new file mode 100644
index 00000000..533335e2
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T1_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_left_body.msh
new file mode 100644
index 00000000..ba38624c
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_right_body.msh
new file mode 100644
index 00000000..61ab8cf5
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T2_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_left_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_left_body.msh
new file mode 100644
index 00000000..c30f9d4f
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_left_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_right_body.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_right_body.msh
new file mode 100644
index 00000000..8a5575a1
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/tibia_T3_right_body.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_brown.msh
new file mode 100644
index 00000000..7c30f774
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_membrane.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_membrane.msh
new file mode 100644
index 00000000..a07ca11a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_left_membrane.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_brown.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_brown.msh
new file mode 100644
index 00000000..ee5f18ec
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_brown.msh differ
diff --git a/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_membrane.msh b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_membrane.msh
new file mode 100644
index 00000000..3550adf8
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/fruitfly_v2/wing_right_membrane.msh differ
diff --git a/dm_control/locomotion/walkers/assets/humanoid_CMU_V2019.xml b/dm_control/locomotion/walkers/assets/humanoid_CMU_V2019.xml
new file mode 100644
index 00000000..e8c59f01
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/humanoid_CMU_V2019.xml
@@ -0,0 +1,298 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/humanoid_CMU_V2020.xml b/dm_control/locomotion/walkers/assets/humanoid_CMU_V2020.xml
new file mode 100644
index 00000000..d13b16ab
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/humanoid_CMU_V2020.xml
@@ -0,0 +1,304 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_body.png b/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_body.png
new file mode 100644
index 00000000..8ab3947a
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_body.png differ
diff --git a/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_with_head.xml b/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_with_head.xml
new file mode 100644
index 00000000..3968f67e
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/jumping_ball/jumping_ball_with_head.xml
@@ -0,0 +1,59 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/rodent.xml b/dm_control/locomotion/walkers/assets/rodent.xml
new file mode 100644
index 00000000..2dbbb9cf
--- /dev/null
+++ b/dm_control/locomotion/walkers/assets/rodent.xml
@@ -0,0 +1,611 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn b/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn
new file mode 100644
index 00000000..651fdc18
Binary files /dev/null and b/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn differ
diff --git a/dm_control/locomotion/walkers/base.py b/dm_control/locomotion/walkers/base.py
new file mode 100644
index 00000000..de7a6f20
--- /dev/null
+++ b/dm_control/locomotion/walkers/base.py
@@ -0,0 +1,200 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for Walkers."""
+
+import abc
+import collections
+
+from dm_control import composer
+from dm_control.composer.observation import observable
+
+from dm_env import specs
+import numpy as np
+
+
+def _make_readonly_float64_copy(value):
+  if np.isscalar(value):
+    return np.float64(value)
+  else:
+    out = np.array(value, dtype=np.float64)
+    out.flags.writeable = False
+    return out
+
+
+class WalkerPose(collections.namedtuple(
+    'WalkerPose', ('qpos', 'xpos', 'xquat'))):
+  """A named tuple representing a walker's joint and Cartesian pose."""
+
+  __slots__ = ()
+
+  def __new__(cls, qpos=None, xpos=(0, 0, 0), xquat=(1, 0, 0, 0)):
+    """Creates a new WalkerPose.
+
+    Args:
+      qpos: The joint position for the pose, or `None` if the `qpos0` values in
+        the `mjModel` should be used.
+      xpos: A Cartesian displacement, for example if the walker should be lifted
+        or lowered by a specific amount for this pose.
+      xquat: A quaternion displacement for the root body.
+
+    Returns:
+      A new instance of `WalkerPose`.
+    """
+    return super(WalkerPose, cls).__new__(
+        cls,
+        qpos=_make_readonly_float64_copy(qpos) if qpos is not None else None,
+        xpos=_make_readonly_float64_copy(xpos),
+        xquat=_make_readonly_float64_copy(xquat))
+
+  def __eq__(self, other):
+    return (np.all(self.qpos == other.qpos) and
+            np.all(self.xpos == other.xpos) and
+            np.all(self.xquat == other.xquat))
+
+
+class Walker(composer.Robot, metaclass=abc.ABCMeta):
+ """Abstract base class for Walker robots."""
+
+ def create_root_joints(self, attachment_frame) -> None:
+ attachment_frame.add('freejoint')
+
+ def _build_observables(self) -> 'WalkerObservables':
+ return WalkerObservables(self)
+
+ def transform_vec_to_egocentric_frame(self, physics, vec_in_world_frame):
+ """Linearly transforms a world-frame vector into walker's egocentric frame.
+
+ Note that this function does not perform an affine transformation of the
+ vector. In other words, the input vector is assumed to be specified with
+ respect to the same origin as this walker's egocentric frame. This function
+ can also be applied to matrices whose innermost dimensions are either 2 or
+ 3. In this case, a matrix with the same leading dimensions is returned
+ where the innermost vectors are replaced by their values computed in the
+ egocentric frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ vec_in_world_frame: A NumPy array with last dimension of shape (2,) or
+ (3,) that represents a vector quantity in the world frame.
+
+ Returns:
+ The same quantity as `vec_in_world_frame` but reexpressed in this
+ entity's egocentric frame. The returned np.array has the same shape as
+ np.asarray(vec_in_world_frame).
+
+ Raises:
+ ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
+ or (3,).
+ """
+ return super().global_vector_to_local_frame(physics, vec_in_world_frame)
+
+ def transform_xmat_to_egocentric_frame(self, physics, xmat):
+ """Transforms another entity's `xmat` into this walker's egocentric frame.
+
+ This function takes another entity's (E) xmat, which is an SO(3) matrix
+ from E's frame to the world frame, and turns it to a matrix that transforms
+ from E's frame into this walker's egocentric frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ xmat: A NumPy array of shape (3, 3) or (9,) that represents another
+ entity's xmat.
+
+ Returns:
+ The `xmat` reexpressed in this entity's egocentric frame. The returned
+ np.array has the same shape as np.asarray(xmat).
+
+ Raises:
+ ValueError: if `xmat` does not have shape (3, 3) or (9,).
+ """
+ return super().global_xmat_to_local_frame(physics, xmat)
+
+ @property
+ @abc.abstractmethod
+ def root_body(self):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def observable_joints(self):
+ raise NotImplementedError
+
+ @property
+ def action_spec(self):
+ if not self.actuators:
+ minimum, maximum = (), ()
+ else:
+ minimum, maximum = zip(*[
+ a.ctrlrange if a.ctrlrange is not None else (-1., 1.)
+ for a in self.actuators
+ ])
+ return specs.BoundedArray(
+ shape=(len(self.actuators),),
+ dtype=float,
+ minimum=minimum,
+ maximum=maximum,
+ name='\t'.join([actuator.name for actuator in self.actuators]))
+
+ def apply_action(self, physics, action, random_state):
+ """Apply action to walker's actuators."""
+ del random_state
+ physics.bind(self.actuators).ctrl = action
+
+
+class WalkerObservables(composer.Observables):
+ """Base class for Walker observables."""
+
+ @composer.observable
+ def joints_pos(self):
+ return observable.MJCFFeature('qpos', self._entity.observable_joints)
+
+ @composer.observable
+ def sensors_gyro(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.gyro)
+
+ @composer.observable
+ def sensors_accelerometer(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.accelerometer)
+
+ @composer.observable
+ def sensors_framequat(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.framequat)
+
+ # Semantic groupings of Walker observables.
+ def _collect_from_attachments(self, attribute_name):
+ out = []
+ for entity in self._entity.iter_entities(exclude_self=True):
+ out.extend(getattr(entity.observables, attribute_name, []))
+ return out
+
+ @property
+ def proprioception(self):
+ return ([self.joints_pos] +
+ self._collect_from_attachments('proprioception'))
+
+ @property
+ def kinematic_sensors(self):
+ return ([self.sensors_gyro,
+ self.sensors_accelerometer,
+ self.sensors_framequat] +
+ self._collect_from_attachments('kinematic_sensors'))
+
+ @property
+ def dynamic_sensors(self):
+ return self._collect_from_attachments('dynamic_sensors')
diff --git a/dm_control/locomotion/walkers/base_test.py b/dm_control/locomotion/walkers/base_test.py
new file mode 100644
index 00000000..d8135f29
--- /dev/null
+++ b/dm_control/locomotion/walkers/base_test.py
@@ -0,0 +1,94 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.walkers.base."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.walkers import base
+import numpy as np
+
+
+class FakeWalker(base.Walker):
+
+ def _build(self):
+ self._mjcf_root = mjcf.RootElement(model='walker')
+ self._torso_body = self._mjcf_root.worldbody.add(
+ 'body', name='torso', xyaxes=[0, 1, 0, -1, 0, 0])
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def actuators(self):
+ return []
+
+ @property
+ def root_body(self):
+ return self._torso_body
+
+ @property
+ def observable_joints(self):
+ return []
+
+
+class BaseWalkerTest(absltest.TestCase):
+
+ def testTransformVectorToEgocentricFrame(self):
+ walker = FakeWalker()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+
+ # 3D vectors
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 1, 0]), [1, 0, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [-1, 0, 0]),
+ [0, 1, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 0, 1]), [0, 0, 1],
+ atol=1e-10)
+
+ # 2D vectors; z-component is ignored
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 1]), [1, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [-1, 0]), [0, 1],
+ atol=1e-10)
+
+ def testTransformMatrixToEgocentricFrame(self):
+ walker = FakeWalker()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+
+ rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])
+ ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])
+
+ np.testing.assert_allclose(
+ walker.transform_xmat_to_egocentric_frame(physics, rotation_atob),
+ ego_rotation_atob, atol=1e-10)
+
+ flat_rotation_atob = np.reshape(rotation_atob, -1)
+ flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1)
+ np.testing.assert_allclose(
+ walker.transform_xmat_to_egocentric_frame(physics, flat_rotation_atob),
+ flat_rotation_ego_atob, atol=1e-10)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walkers/cmu_humanoid.py b/dm_control/locomotion/walkers/cmu_humanoid.py
new file mode 100644
index 00000000..3054ca67
--- /dev/null
+++ b/dm_control/locomotion/walkers/cmu_humanoid.py
@@ -0,0 +1,494 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A CMU humanoid walker."""
+
+import abc
+import collections
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+from dm_control.locomotion.walkers import rescale
+from dm_control.locomotion.walkers import scaled_actuators
+from dm_control.mujoco import wrapper as mj_wrapper
+import numpy as np
+
+_XML_PATH = os.path.join(os.path.dirname(__file__),
+ 'assets/humanoid_CMU_V{model_version}.xml')
+_WALKER_GEOM_GROUP = 2
+_WALKER_INVIS_GROUP = 1
+
+_CMU_MOCAP_JOINTS = (
+ 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
+ 'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz',
+ 'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx',
+ 'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry',
+ 'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz',
+ 'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz',
+ 'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
+ 'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
+ 'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx',
+ 'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz',
+ 'rthumbrx')
+
+
+# pylint: disable=bad-whitespace
+PositionActuatorParams = collections.namedtuple(
+ 'PositionActuatorParams', ['name', 'forcerange', 'kp'])
+_POSITION_ACTUATORS = [
+ PositionActuatorParams('headrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('headry', [-20, 20 ], 20 ),
+ PositionActuatorParams('headrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lclaviclery', [-20, 20 ], 20 ),
+ PositionActuatorParams('lclaviclerz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lfemurrx', [-120, 120], 120),
+ PositionActuatorParams('lfemurry', [-80, 80 ], 80 ),
+ PositionActuatorParams('lfemurrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('lfingersrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lfootrx', [-50, 50 ], 50 ),
+ PositionActuatorParams('lfootrz', [-50, 50 ], 50 ),
+ PositionActuatorParams('lhandrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lhandrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lhumerusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('lhumerusry', [-60, 60 ], 60 ),
+ PositionActuatorParams('lhumerusrz', [-60, 60 ], 60 ),
+ PositionActuatorParams('lowerbackrx', [-120, 120], 150),
+ PositionActuatorParams('lowerbackry', [-120, 120], 150),
+ PositionActuatorParams('lowerbackrz', [-120, 120], 150),
+ PositionActuatorParams('lowerneckrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lowerneckry', [-20, 20 ], 20 ),
+ PositionActuatorParams('lowerneckrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lradiusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('lthumbrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lthumbrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('ltibiarx', [-80, 80 ], 80 ),
+ PositionActuatorParams('ltoesrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lwristry', [-20, 20 ], 20 ),
+ PositionActuatorParams('rclaviclery', [-20, 20 ], 20 ),
+ PositionActuatorParams('rclaviclerz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rfemurrx', [-120, 120], 120),
+ PositionActuatorParams('rfemurry', [-80, 80 ], 80 ),
+ PositionActuatorParams('rfemurrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('rfingersrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rfootrx', [-50, 50 ], 50 ),
+ PositionActuatorParams('rfootrz', [-50, 50 ], 50 ),
+ PositionActuatorParams('rhandrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rhandrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rhumerusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('rhumerusry', [-60, 60 ], 60 ),
+ PositionActuatorParams('rhumerusrz', [-60, 60 ], 60 ),
+ PositionActuatorParams('rradiusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('rthumbrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rthumbrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rtibiarx', [-80, 80 ], 80 ),
+ PositionActuatorParams('rtoesrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rwristry', [-20, 20 ], 20 ),
+ PositionActuatorParams('thoraxrx', [-80, 80 ], 100),
+ PositionActuatorParams('thoraxry', [-80, 80 ], 100),
+ PositionActuatorParams('thoraxrz', [-80, 80 ], 100),
+ PositionActuatorParams('upperbackrx', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperbackry', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperbackrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperneckrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('upperneckry', [-20, 20 ], 20 ),
+ PositionActuatorParams('upperneckrz', [-20, 20 ], 20 ),
+]
+PositionActuatorParamsV2020 = collections.namedtuple(
+ 'PositionActuatorParamsV2020', ['name', 'forcerange', 'kp', 'damping'])
+_POSITION_ACTUATORS_V2020 = [
+ PositionActuatorParamsV2020('headrx', [-40, 40 ], 40 , 2 ),
+ PositionActuatorParamsV2020('headry', [-40, 40 ], 40 , 2 ),
+ PositionActuatorParamsV2020('headrz', [-40, 40 ], 40 , 2 ),
+ PositionActuatorParamsV2020('lclaviclery', [-80, 80 ], 80 , 20),
+ PositionActuatorParamsV2020('lclaviclerz', [-80, 80 ], 80 , 20),
+ PositionActuatorParamsV2020('lfemurrx', [-300, 300], 300, 15),
+ PositionActuatorParamsV2020('lfemurry', [-200, 200], 200, 10),
+ PositionActuatorParamsV2020('lfemurrz', [-200, 200], 200, 10),
+ PositionActuatorParamsV2020('lfingersrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('lfootrx', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('lfootrz', [-50, 50 ], 50 , 3 ),
+ PositionActuatorParamsV2020('lhandrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('lhandrz', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('lhumerusrx', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('lhumerusry', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('lhumerusrz', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('lowerbackrx', [-300, 300], 300, 15),
+ PositionActuatorParamsV2020('lowerbackry', [-180, 180], 180, 20),
+ PositionActuatorParamsV2020('lowerbackrz', [-200, 200], 200, 20),
+ PositionActuatorParamsV2020('lowerneckrx', [-120, 120], 120, 20),
+ PositionActuatorParamsV2020('lowerneckry', [-120, 120], 120, 20),
+ PositionActuatorParamsV2020('lowerneckrz', [-120, 120], 120, 20),
+ PositionActuatorParamsV2020('lradiusrx', [-90, 90 ], 90 , 5 ),
+ PositionActuatorParamsV2020('lthumbrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('lthumbrz', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('ltibiarx', [-160, 160], 160, 8 ),
+ PositionActuatorParamsV2020('ltoesrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('lwristry', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rclaviclery', [-80, 80 ], 80 , 20),
+ PositionActuatorParamsV2020('rclaviclerz', [-80, 80 ], 80 , 20),
+ PositionActuatorParamsV2020('rfemurrx', [-300, 300], 300, 15),
+ PositionActuatorParamsV2020('rfemurry', [-200, 200], 200, 10),
+ PositionActuatorParamsV2020('rfemurrz', [-200, 200], 200, 10),
+ PositionActuatorParamsV2020('rfingersrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rfootrx', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('rfootrz', [-50, 50 ], 50 , 3 ),
+ PositionActuatorParamsV2020('rhandrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rhandrz', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rhumerusrx', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('rhumerusry', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('rhumerusrz', [-120, 120], 120, 6 ),
+ PositionActuatorParamsV2020('rradiusrx', [-90, 90 ], 90 , 5 ),
+ PositionActuatorParamsV2020('rthumbrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rthumbrz', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rtibiarx', [-160, 160], 160, 8 ),
+ PositionActuatorParamsV2020('rtoesrx', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('rwristry', [-20, 20 ], 20 , 1 ),
+ PositionActuatorParamsV2020('thoraxrx', [-300, 300], 300, 15),
+ PositionActuatorParamsV2020('thoraxry', [-80, 80], 80 , 8 ),
+ PositionActuatorParamsV2020('thoraxrz', [-200, 200], 200, 12),
+ PositionActuatorParamsV2020('upperbackrx', [-300, 300], 300, 15),
+ PositionActuatorParamsV2020('upperbackry', [-80, 80], 80 , 8 ),
+ PositionActuatorParamsV2020('upperbackrz', [-200, 200], 200, 12),
+ PositionActuatorParamsV2020('upperneckrx', [-60, 60 ], 60 , 10),
+ PositionActuatorParamsV2020('upperneckry', [-60, 60 ], 60 , 10),
+ PositionActuatorParamsV2020('upperneckrz', [-60, 60 ], 60 , 10),
+]
+
+# pylint: enable=bad-whitespace
+
+_UPRIGHT_POS = (0.0, 0.0, 0.94)
+_UPRIGHT_POS_V2020 = (0.0, 0.0, 1.143)
+_UPRIGHT_QUAT = (0.859, 1.0, 1.0, 0.859)
+
+# Height of head above which the humanoid is considered standing.
+_STAND_HEIGHT = 1.5
+
+_TORQUE_THRESHOLD = 60
+
+
+class _CMUHumanoidBase(legacy_base.Walker, metaclass=abc.ABCMeta):
+ """The abstract base class for walkers compatible with the CMU humanoid."""
+
+ def _build(self,
+ name='walker',
+ marker_rgba=None,
+ include_face=False,
+ initializer=None):
+ self._mjcf_root = mjcf.from_path(self._xml_path)
+ if name:
+ self._mjcf_root.model = name
+
+ # Set corresponding marker color if specified.
+ if marker_rgba is not None:
+ for geom in self.marker_geoms:
+ geom.set_attributes(rgba=marker_rgba)
+
+ self._actuator_order = np.argsort(_CMU_MOCAP_JOINTS)
+ self._inverse_order = np.argsort(self._actuator_order)
+
+ super()._build(initializer=initializer)
+
+ if include_face:
+ head = self._mjcf_root.find('body', 'head')
+ head.add(
+ 'geom',
+ type='capsule',
+ name='face',
+ size=(0.065, 0.014),
+ pos=(0.000341465, 0.048184, 0.01),
+ quat=(0.717887, 0.696142, -0.00493334, 0),
+ mass=0.,
+ contype=0,
+ conaffinity=0)
+
+ face_forwardness = head.pos[1]-.02
+ head_geom = self._mjcf_root.find('geom', 'head')
+ nose_size = head_geom.size[0] / 4.75
+ face = head.add(
+ 'body', name='face', pos=(0.0, 0.039, face_forwardness))
+ face.add('geom',
+ type='capsule',
+ name='nose',
+ size=(nose_size, 0.01),
+ pos=(0.0, 0.0, 0.0),
+ quat=(1, 0.7, 0, 0),
+ mass=0.,
+ contype=0,
+ conaffinity=0,
+ group=_WALKER_INVIS_GROUP)
+
+ def _build_observables(self):
+ return CMUHumanoidObservables(self)
+
+ @property
+ @abc.abstractmethod
+ def _xml_path(self):
+ raise NotImplementedError
+
+ @composer.cached_property
+ def mocap_joints(self):
+ return tuple(
+ self._mjcf_root.find('joint', name) for name in _CMU_MOCAP_JOINTS)
+
+ @property
+ def actuator_order(self):
+ """Index of joints from the CMU mocap dataset sorted alphabetically by name.
+
+ Actuators in this walker are ordered alphabetically by name. This property
+ provides a mapping from actuator ordering to canonical CMU ordering.
+
+ Returns:
+ A list of integers corresponding to joint indices from the CMU dataset.
+ Specifically, the n-th element in the list is the index of the CMU joint
+ that corresponds to the n-th actuator in this walker.
+ """
+ return self._actuator_order
+
+ @property
+ def actuator_to_joint_order(self):
+ """Index of actuators corresponding to each CMU mocap joint.
+
+ Actuators in this walker are ordered alphabetically by name. This property
+ provides a mapping from canonical CMU ordering to actuator ordering.
+
+ Returns:
+ A list of integers corresponding to actuator indices within this walker.
+ Specifically, the n-th element in the list is the index of the actuator
+ in this walker that corresponds to the n-th joint from the CMU mocap
+ dataset.
+ """
+ return self._inverse_order
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ return tuple(self._mjcf_root.find_all('actuator'))
+
+ @composer.cached_property
+ def root_body(self):
+ return self._mjcf_root.find('body', 'root')
+
+ @composer.cached_property
+ def head(self):
+ return self._mjcf_root.find('body', 'head')
+
+ @composer.cached_property
+ def left_arm_root(self):
+ return self._mjcf_root.find('body', 'lclavicle')
+
+ @composer.cached_property
+ def right_arm_root(self):
+ return self._mjcf_root.find('body', 'rclavicle')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ return tuple(self._mjcf_root.find('body', 'lfoot').find_all('geom') +
+ self._mjcf_root.find('body', 'rfoot').find_all('geom'))
+
+ @composer.cached_property
+ def standing_height(self):
+ return _STAND_HEIGHT
+
+ @composer.cached_property
+ def end_effectors(self):
+ return (self._mjcf_root.find('body', 'rradius'),
+ self._mjcf_root.find('body', 'lradius'),
+ self._mjcf_root.find('body', 'rfoot'),
+ self._mjcf_root.find('body', 'lfoot'))
+
+ @composer.cached_property
+ def observable_joints(self):
+ return tuple(actuator.joint for actuator in self.actuators
+ if actuator.joint is not None)
+
+ @composer.cached_property
+ def bodies(self):
+ return tuple(self._mjcf_root.find_all('body'))
+
+ @composer.cached_property
+ def mocap_tracking_bodies(self):
+ """Collection of bodies for mocap tracking."""
+ # Exclude the root body.
+ root_body = self._mjcf_root.find('body', 'root')
+ return tuple(
+ b for b in self._mjcf_root.find_all('body') if b != root_body)
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @composer.cached_property
+ def body_camera(self):
+ return self._mjcf_root.find('camera', 'bodycam')
+
+ @property
+ def marker_geoms(self):
+ return (self._mjcf_root.find('geom', 'rradius'),
+ self._mjcf_root.find('geom', 'lradius'))
+
+
+class CMUHumanoid(_CMUHumanoidBase):
+ """A CMU humanoid walker."""
+
+ @property
+ def _xml_path(self):
+ return _XML_PATH.format(model_version='2019')
+
+
+class CMUHumanoidPositionControlled(CMUHumanoid):
+ """A position-controlled CMU humanoid with control range scaled to [-1, 1]."""
+
+ def _build(self, model_version='2019', **kwargs):
+ self._version = model_version
+ if 'scale_default' in kwargs:
+ scale_default = kwargs['scale_default']
+ del kwargs['scale_default']
+ else:
+ scale_default = False
+
+ super()._build(**kwargs)
+
+ if scale_default:
+ # NOTE: This rescaling doesn't affect the attached hands
+ rescale.rescale_humanoid(self, 1.2, 1.2, 70)
+
+ # modify actuators
+ if self._version == '2020':
+ position_actuators = _POSITION_ACTUATORS_V2020
+ else:
+ position_actuators = _POSITION_ACTUATORS
+ self._mjcf_root.default.general.forcelimited = 'true'
+ self._mjcf_root.actuator.motor.clear()
+ for actuator_params in position_actuators:
+ associated_joint = self._mjcf_root.find('joint', actuator_params.name)
+ if hasattr(actuator_params, 'damping'):
+ associated_joint.damping = actuator_params.damping
+ actuator = scaled_actuators.add_position_actuator(
+ name=actuator_params.name,
+ target=associated_joint,
+ kp=actuator_params.kp,
+ qposrange=associated_joint.range,
+ ctrlrange=(-1, 1),
+ forcerange=actuator_params.forcerange)
+ if self._version == '2020':
+ actuator.dyntype = 'filter'
+ actuator.dynprm = [0.030]
+ limits = zip(*(actuator.joint.range for actuator in self.actuators)) # pylint: disable=not-an-iterable
+ lower, upper = (np.array(limit) for limit in limits)
+ self._scale = upper - lower
+ self._offset = upper + lower
+
+ @property
+ def _xml_path(self):
+ return _XML_PATH.format(model_version=self._version)
+
+ def cmu_pose_to_actuation(self, target_pose):
+ """Creates the control signal corresponding to a CMU mocap joint pose.
+
+ Args:
+ target_pose: An array containing the target position for each joint.
+ These must be given in "canonical CMU order" rather than "qpos order",
+ i.e. the order of `target_pose[self.actuator_order]` should correspond
+ to the order of `physics.bind(self.actuators).ctrl`.
+
+ Returns:
+ An array of the same shape as `target_pose` containing inputs for position
+ controllers. Writing these values into `physics.bind(self.actuators).ctrl`
+ will cause the actuators to drive joints towards `target_pose`.
+ """
+ return (2 * target_pose[self.actuator_order] - self._offset) / self._scale
+
+
+class CMUHumanoidPositionControlledV2020(CMUHumanoidPositionControlled):
+ """A 2020 updated CMU humanoid walker; includes nose for head orientation."""
+
+ def _build(self, **kwargs):
+ super()._build(
+ model_version='2020', scale_default=True, include_face=True, **kwargs)
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose(xpos=_UPRIGHT_POS_V2020, xquat=_UPRIGHT_QUAT)
+
+
+class CMUHumanoidObservables(legacy_base.WalkerObservables):
+ """Observables for the Humanoid."""
+
+ @composer.observable
+ def body_camera(self):
+ options = mj_wrapper.MjvOption()
+
+ # Don't render this walker's geoms.
+ options.geomgroup[_WALKER_GEOM_GROUP] = 0
+ return observable.MJCFCamera(
+ self._entity.body_camera, width=64, height=64, scene_option=options)
+
+ @composer.observable
+ def egocentric_camera(self):
+ options = mj_wrapper.MjvOption()
+
+ # Don't render this walker's geoms.
+ options.geomgroup[_WALKER_INVIS_GROUP] = 0
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=64, height=64, scene_option=options)
+
+ @composer.observable
+ def head_height(self):
+ return observable.MJCFFeature('xpos', self._entity.head)[2]
+
+ @composer.observable
+ def sensors_torque(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.torque,
+ corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD))
+
+ @composer.observable
+ def actuator_activation(self):
+ return observable.MJCFFeature('act',
+ self._entity.mjcf_model.find_all('actuator'))
+
+ @composer.observable
+ def appendages_pos(self):
+ """Equivalent to `end_effectors_pos` with the head's position appended."""
+ def relative_pos_in_egocentric_frame(physics):
+ end_effectors_with_head = (
+ self._entity.end_effectors + (self._entity.head,))
+ end_effector = physics.bind(end_effectors_with_head).xpos
+ torso = physics.bind(self._entity.root_body).xpos
+ xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(end_effector - torso, xmat), -1)
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @property
+ def proprioception(self):
+ return [
+ self.joints_pos,
+ self.joints_vel,
+ self.actuator_activation,
+ self.body_height,
+ self.end_effectors_pos,
+ self.appendages_pos,
+ self.world_zaxis
+ ] + self._collect_from_attachments('proprioception')
diff --git a/dm_control/locomotion/walkers/cmu_humanoid_test.py b/dm_control/locomotion/walkers/cmu_humanoid_test.py
new file mode 100644
index 00000000..4eb9742f
--- /dev/null
+++ b/dm_control/locomotion/walkers/cmu_humanoid_test.py
@@ -0,0 +1,161 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the CMU humanoid."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+
+
+class CMUHumanoidTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ cmu_humanoid.CMUHumanoid,
+ cmu_humanoid.CMUHumanoidPositionControlled,
+ ])
+ def test_can_compile_and_step_simulation(self, walker_type):
+ walker = walker_type()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+ for _ in range(100):
+ physics.step()
+
+ @parameterized.parameters([
+ cmu_humanoid.CMUHumanoid,
+ cmu_humanoid.CMUHumanoidPositionControlled,
+ ])
+ def test_actuators_sorted_alphabetically(self, walker_type):
+ walker = walker_type()
+ actuator_names = [
+ actuator.name for actuator in walker.mjcf_model.find_all('actuator')]
+ np.testing.assert_array_equal(actuator_names, sorted(actuator_names))
+
+ def test_actuator_to_mocap_joint_mapping(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ with self.subTest('Forward mapping'):
+ for actuator_num, cmu_mocap_joint_num in enumerate(walker.actuator_order):
+ self.assertEqual(walker.actuator_to_joint_order[cmu_mocap_joint_num],
+ actuator_num)
+
+ with self.subTest('Inverse mapping'):
+ for cmu_mocap_joint_num, actuator_num in enumerate(
+ walker.actuator_to_joint_order):
+ self.assertEqual(walker.actuator_order[actuator_num],
+ cmu_mocap_joint_num)
+
+ def test_cmu_humanoid_position_controlled_has_correct_actuators(self):
+ walker_torque = cmu_humanoid.CMUHumanoid()
+ walker_pos = cmu_humanoid.CMUHumanoidPositionControlled()
+
+ actuators_torque = walker_torque.mjcf_model.find_all('actuator')
+ actuators_pos = walker_pos.mjcf_model.find_all('actuator')
+
+ actuator_pos_params = {
+ params.name: params for params in cmu_humanoid._POSITION_ACTUATORS}
+
+ self.assertEqual(len(actuators_torque), len(actuators_pos))
+
+ for actuator_torque, actuator_pos in zip(actuators_torque, actuators_pos):
+ self.assertEqual(actuator_pos.name, actuator_torque.name)
+ self.assertEqual(actuator_pos.joint.full_identifier,
+ actuator_torque.joint.full_identifier)
+ self.assertEqual(actuator_pos.tag, 'general')
+ self.assertEqual(actuator_pos.ctrllimited, 'true')
+ np.testing.assert_array_equal(actuator_pos.ctrlrange, (-1, 1))
+
+ expected_params = actuator_pos_params[actuator_pos.name]
+ self.assertEqual(actuator_pos.biasprm[1], -expected_params.kp)
+ np.testing.assert_array_equal(actuator_pos.forcerange,
+ expected_params.forcerange)
+
+ @parameterized.parameters([
+ 'body_camera',
+ 'egocentric_camera',
+ 'head',
+ 'left_arm_root',
+ 'right_arm_root',
+ 'root_body',
+ ])
+ def test_get_element_property(self, name):
+ attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name)
+ self.assertIsInstance(attribute_value, mjcf.Element)
+
+ @parameterized.parameters([
+ 'actuators',
+ 'bodies',
+ 'end_effectors',
+ 'marker_geoms',
+ 'mocap_joints',
+ 'observable_joints',
+ ])
+ def test_get_element_tuple_property(self, name):
+ attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name)
+ self.assertNotEmpty(attribute_value)
+ for item in attribute_value:
+ self.assertIsInstance(item, mjcf.Element)
+
+ def test_set_name(self):
+ name = 'fred'
+ walker = cmu_humanoid.CMUHumanoid(name=name)
+ self.assertEqual(walker.mjcf_model.model, name)
+
+ def test_set_marker_rgba(self):
+ marker_rgba = (1., 0., 1., 0.5)
+ walker = cmu_humanoid.CMUHumanoid(marker_rgba=marker_rgba)
+ for marker_geom in walker.marker_geoms:
+ np.testing.assert_array_equal(marker_geom.rgba, marker_rgba)
+
+ @parameterized.parameters(
+ 'actuator_activation',
+ 'appendages_pos',
+ 'body_camera',
+ 'head_height',
+ 'sensors_torque',
+ )
+ def test_evaluate_observable(self, name):
+ walker = cmu_humanoid.CMUHumanoid()
+ observable = getattr(walker.observables, name)
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+ observation = observable(physics)
+ self.assertIsInstance(observation, (float, np.ndarray))
+
+ def test_proprioception(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ for item in walker.observables.proprioception:
+ self.assertIsInstance(item, observable_base.Observable)
+
+ def test_cmu_pose_to_actuation(self):
+ walker = cmu_humanoid.CMUHumanoidPositionControlled()
+ random_state = np.random.RandomState(123)
+
+ expected_actuation = random_state.uniform(-1, 1, len(walker.actuator_order))
+
+ cmu_limits = zip(*(joint.range for joint in walker.mocap_joints))
+ cmu_lower, cmu_upper = (np.array(limit) for limit in cmu_limits)
+ cmu_pose = cmu_lower + (cmu_upper - cmu_lower) * (
+ 1 + expected_actuation[walker.actuator_to_joint_order]) / 2
+
+ actual_actuation = walker.cmu_pose_to_actuation(cmu_pose)
+
+ np.testing.assert_allclose(actual_actuation, expected_actuation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walkers/fruitfly_v2.py b/dm_control/locomotion/walkers/fruitfly_v2.py
new file mode 100644
index 00000000..e25bc57f
--- /dev/null
+++ b/dm_control/locomotion/walkers/fruitfly_v2.py
@@ -0,0 +1,675 @@
+# Copyright 2023 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Fruit fly model."""
+
+import collections as col
+import os
+from typing import Sequence
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+from dm_control.mujoco import wrapper as mj_wrapper
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import transformations
+from dm_env import specs
+import numpy as np
+enums = mjbindings.enums
+mjlib = mjbindings.mjlib
+
+
+_XML_PATH = os.path.join(os.path.dirname(__file__),
+ 'assets/fruitfly_v2/fruitfly.xml')
+# === Constants.
+_SPAWN_POS = np.array((0, 0, 0.1278))
+# OrderedDict used to streamline enabling/disabling of action classes.
+_ACTION_CLASSES = col.OrderedDict(adhesion=0,
+ head=0,
+ mouth=0,
+ antennae=0,
+ wings=0,
+ abdomen=0,
+ legs=0,
+ user=0)
+
+
+ def neg_quat(quat_a):
+ """Returns the inverse rotation of quat_a by negating its scalar part."""
+ quat_b = quat_a.copy()
+ quat_b[0] *= -1
+ return quat_b
+
+
+def mul_quat(quat_a, quat_b):
+ """Returns quat_a * quat_b."""
+ quat_c = np.zeros(4)
+ mjlib.mju_mulQuat(quat_c, quat_a, quat_b)
+ return quat_c
+
+
+def mul_jac_t_vec(physics, efc):
+ """Maps forces from constraint space to joint space."""
+ qfrc = np.zeros(physics.model.nv)
+ mjlib.mj_mulJacTVec(physics.model.ptr, physics.data.ptr, qfrc, efc)
+ return qfrc
+
+
+def rot_vec_quat(vec, quat):
+ """Rotates vector with quaternion."""
+ res = np.zeros(3)
+ mjlib.mju_rotVecQuat(res, vec, quat)
+ return res
+
+
+def any_substr_in_str(substrings: Sequence[str], string: str) -> bool:
+ """Checks if any of substrings is in string."""
+ return any(s in string for s in substrings)
+
+
+ def body_quat_from_springrefs(body: 'mjcf.Element') -> np.ndarray:
+ """Computes new body quat from all joint springrefs and current quat."""
+ joints = body.joint
+ if not joints:
+ return None
+ # Construct quaternions for all joint axes.
+ quats = []
+ for joint in joints:
+ theta = joint.springref or joint.dclass.joint.springref or 0
+ axis = joint.axis or joint.dclass.joint.axis
+ if axis is None:
+ axis = joint.dclass.parent.joint.axis
+ quats.append(np.hstack((np.cos(theta/2), np.sin(theta/2) * axis)))
+ # Compute the new orientation quaternion.
+ quat = np.array([1., 0, 0, 0])
+ for i in range(len(quats)):
+ quat = transformations.quat_mul(quats[-1-i], quat)
+ if body.quat is not None:
+ quat = transformations.quat_mul(body.quat, quat)
+ return quat
+
+
+def change_body_frame(body, frame_pos, frame_quat):
+ """Change the frame of a body while maintaining child locations."""
+ frame_pos = np.zeros(3) if frame_pos is None else frame_pos
+ frame_quat = np.array((1., 0, 0, 0)) if frame_quat is None else frame_quat
+ # Get frame transformation.
+ body_pos = np.zeros(3) if body.pos is None else body.pos
+ dpos = body_pos - frame_pos
+ body_quat = np.array((1., 0, 0, 0)) if body.quat is None else body.quat
+ dquat = mul_quat(neg_quat(frame_quat), body_quat)
+ # Translate and rotate the body to the new frame.
+ body.pos = frame_pos
+ body.quat = frame_quat
+ # Move all its children to their previous location.
+ for child in body.all_children():
+ if not hasattr(child, 'pos'):
+ continue
+ # Rotate:
+ if hasattr(child, 'quat'):
+ child_quat = np.array((1., 0, 0, 0)) if child.quat is None else child.quat
+ child.quat = mul_quat(dquat, child_quat)
+ # Translate, accounting for rotations.
+ child_pos = np.zeros(3) if child.pos is None else child.pos
+ pos_in_parent = rot_vec_quat(child_pos, body_quat) + dpos
+ child.pos = rot_vec_quat(pos_in_parent, neg_quat(frame_quat))
+
+
+#-------------------------------------------------------------------------------
+
+
+class FruitFly(legacy_base.Walker):
+ """A fruit fly model."""
+
+ def _build(self,
+ name: str = 'walker',
+ use_legs: bool = True,
+ use_wings: bool = False,
+ use_mouth: bool = False,
+ use_antennae: bool = False,
+ joint_filter: float = 0.01,
+ adhesion_filter: float = 0.01,
+ body_pitch_angle: float = 47.5,
+ stroke_plane_angle: float = 0.,
+ physics_timestep: float = 1e-4,
+ control_timestep: float = 2e-3,
+ num_user_actions: int = 0,
+ eye_camera_fovy: float = 150.,
+ eye_camera_size: int = 32,
+ ):
+ """Build a fruitfly walker.
+
+ Args:
+ name: Name of the walker.
+ use_legs: Whether to use or retract the legs.
+ use_wings: Whether to use or retract the wings.
+ use_mouth: Whether to use or retract the mouth.
+ use_antennae: Whether to use the antennae.
+ joint_filter: Timescale of filter for joint actuators. 0: disabled.
+ adhesion_filter: Timescale of filter for adhesion actuators. 0: disabled.
+ body_pitch_angle: Body pitch angle for initial flight pose, relative to
+ ground, degrees. 0: horizontal body position. Default value from
+ https://doi.org/10.1126/science.1248955
+ stroke_plane_angle: Angle of wing stroke plane for initial flight pose,
+ relative to ground, degrees. 0: horizontal stroke plane.
+ physics_timestep: Timestep of the simulation.
+ control_timestep: Timestep of the controller.
+ num_user_actions: Optional, number of additional actions for custom usage,
+ e.g. in before_step callback. The action range is [-1, 1]. 0: Not used.
+ eye_camera_fovy: Vertical field of view of the eye cameras, degrees. The
+ horizontal field of view is computed automatically given the window
+ size.
+ eye_camera_size: Size in pixels (height and width) of the eye cameras.
+ Height and width are assumed equal.
+ """
+ self._adhesion_filter = adhesion_filter
+ self._control_timestep = control_timestep
+ self._buffer_size = int(round(control_timestep/physics_timestep))
+ self._eye_camera_size = eye_camera_size
+ root = mjcf.from_path(_XML_PATH)
+ self._mjcf_root = root
+ if name:
+ self._mjcf_root.model = name
+
+ # Remove freejoint.
+ root.find('joint', 'free').remove()
+ # Set eye camera fovy.
+ root.find('camera', 'eye_right').fovy = eye_camera_fovy
+ root.find('camera', 'eye_left').fovy = eye_camera_fovy
+
+ # Identify actuator/body/joint/tendon classes by substrings in their names.
+ name_substr = {'adhesion': [],
+ 'head': ['head'],
+ 'mouth': ['rostrum', 'haustellum', 'labrum'],
+ 'antennae': ['antenna'],
+ 'wings': ['wing'],
+ 'abdomen': ['abdomen'],
+ 'legs': ['T1', 'T2', 'T3'],
+ 'user': []}
+
+ # === Retract disabled body parts and remove their actuators.
+
+ # Maybe retract and disable legs.
+ if not use_legs:
+ # Set orientation quaternions to retracted leg position.
+ leg_bodies = [b for b in root.find_all('body')
+ if any_substr_in_str(name_substr['legs'], b.name)]
+ for body in leg_bodies:
+ body.quat = body_quat_from_springrefs(body)
+ # Remove leg tendons and tendon actuators.
+ for tendon in root.find_all('tendon'):
+ if any_substr_in_str(name_substr['legs'], tendon.name):
+ # Assume tendon actuator names are the same as tendon names.
+ actuator = root.find('actuator', tendon.name)
+ if actuator is not None:
+ actuator.remove()
+ tendon.remove()
+ # Remove leg actuators and joints.
+ leg_joints = [j for j in root.find_all('joint')
+ if any_substr_in_str(name_substr['legs'], j.name)]
+ for joint in leg_joints:
+ # Assume joint actuator names are the same as joint names.
+ actuator = root.find('actuator', joint.name)
+ if actuator is not None:
+ actuator.remove()
+ self.observable_joints.remove(joint)
+ joint.remove()
+ # Remove leg adhesion actuators.
+ for actuator in root.find_all('actuator'):
+ if ('adhere' in actuator.name and
+ any_substr_in_str(name_substr['legs'], actuator.name)):
+ actuator.remove()
+
+ # Maybe retract and disable wings.
+ if not use_wings:
+ wing_joints = [j for j in root.find_all('joint')
+ if any_substr_in_str(name_substr['wings'], j.name)]
+ for joint in wing_joints:
+ root.find('actuator', joint.name).remove()
+ self.observable_joints.remove(joint)
+
+ # Maybe disable mouth.
+ if not use_mouth:
+ mouth_joints = [j for j in root.find_all('joint')
+ if any_substr_in_str(name_substr['mouth'], j.name)]
+ for joint in mouth_joints:
+ root.find('actuator', joint.name).remove()
+ self.observable_joints.remove(joint)
+ # Remove mouth adhesion actuators.
+ for actuator in root.find_all('actuator'):
+ if ('adhere' in actuator.name and
+ any_substr_in_str(name_substr['mouth'], actuator.name)):
+ actuator.remove()
+
+ # Maybe disable antennae.
+ if not use_antennae:
+ antenna_joints = [j for j in root.find_all('joint')
+ if any_substr_in_str(name_substr['antennae'], j.name)]
+ for joint in antenna_joints:
+ root.find('actuator', joint.name).remove()
+ self.observable_joints.remove(joint)
+
+ # === For flight, set body pitch angle and stroke plane angle.
+ if use_wings:
+ # == Set body pitch angle.
+ up_dir = root.find('site', 'hover_up_dir').quat
+ up_dir_angle = 2 * np.arccos(up_dir[0])
+ delta = np.deg2rad(body_pitch_angle) - up_dir_angle
+ dquat = np.array([np.cos(delta/2), 0, np.sin(delta/2), 0])
+ # Rotate up_dir to new angle.
+ up_dir[:] = mul_quat(dquat, up_dir)
+ # == Set stroke plane angle.
+ stroke_plane_angle = np.deg2rad(stroke_plane_angle)
+ stroke_plane_quat = np.array([np.cos(stroke_plane_angle/2), 0,
+ np.sin(stroke_plane_angle/2), 0])
+ for quat, wing in [(np.array([0., 0, 0, 1]), 'wing_left'),
+ (np.array([0., -1, 0, 0]), 'wing_right')]:
+ # Rotate wing-joint frame.
+ dquat = mul_quat(neg_quat(stroke_plane_quat), quat)
+ new_wing_quat = mul_quat(dquat, neg_quat(up_dir))
+ body = root.find('body', wing)
+ change_body_frame(body, body.pos, new_wing_quat)
+
+ # === Maybe change actuator dynamics to `filter`.
+ if joint_filter > 0:
+ for actuator in root.find_all('actuator'):
+ if actuator.tag != 'adhesion':
+ actuator.dyntype = 'filter'
+ actuator.dynprm = (joint_filter,)
+ if adhesion_filter > 0:
+ for actuator in root.find_all('actuator'):
+ if actuator.tag == 'adhesion':
+ actuator.dclass.parent.general.dyntype = 'filter'
+ actuator.dclass.parent.general.dynprm = (adhesion_filter,)
+
+ # === Get action-class indices into the MuJoCo control vector.
+ # Find all ctrl indices except adhesion.
+ self._ctrl_indices = _ACTION_CLASSES.copy()
+ names = [a.name for a in root.find_all('actuator')]
+ for act_class in self._ctrl_indices.keys():
+ indices = [i for i, name in enumerate(names)
+ if any_substr_in_str(name_substr[act_class], name)
+ and 'adhere' not in name]
+ self._ctrl_indices[act_class] = indices if indices else None
+ # Find adhesion ctrl indices.
+ indices = [i for i, name in enumerate(names) if 'adhere' in name]
+ self._ctrl_indices['adhesion'] = indices if indices else None
+
+ # === Count the number of actions in each action-class.
+ self._num_actions = _ACTION_CLASSES.copy()
+
+ # User actions, if any.
+ self._num_actions['user'] = num_user_actions
+
+ # The rest of action classes, including adhesion.
+ for act_class in self._num_actions.keys():
+ if self._ctrl_indices[act_class] is not None:
+ self._num_actions[act_class] = len(self._ctrl_indices[act_class])
+
+ # === Get action-class indices into the environment action vector.
+ self._action_indices = _ACTION_CLASSES.copy()
+ counter = 0
+ for act_class in _ACTION_CLASSES.keys():
+ if self._num_actions[act_class]:
+ indices = list(range(counter, counter + self._num_actions[act_class]))
+ self._action_indices[act_class] = indices
+ counter += self._num_actions[act_class]
+ else:
+ self._action_indices[act_class] = []
+
+ super()._build()
+
+ #-----------------------------------------------------------------------------
+
+ def initialize_episode(self, physics: 'mjcf.Physics',
+ random_state: np.random.RandomState):
+ """Caches the walker's weight at the start of each episode."""
+ # Save the weight of the body (in Dyne i.e. gram*cm/s^2).
+ body_mass = physics.named.model.body_subtreemass['walker/thorax'] # gram.
+ self._weight = np.linalg.norm(physics.model.opt.gravity) * body_mass
+
+ #-----------------------------------------------------------------------------
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose(xpos=_SPAWN_POS)
+
+ @property
+ def weight(self):
+ return self._weight
+
+ @property
+ def adhesion_filter(self):
+ return self._adhesion_filter
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def root_body(self):
+ """Return the root body (the thorax)."""
+ return self.mjcf_model.find('body', 'thorax')
+
+ @composer.cached_property
+ def thorax(self):
+ """Return the thorax."""
+ return self.mjcf_model.find('body', 'thorax')
+
+ @composer.cached_property
+ def abdomen(self):
+ """Return the abdomen."""
+ return self.mjcf_model.find('body', 'abdomen')
+
+ @composer.cached_property
+ def head(self):
+ """Return the head."""
+ return self.mjcf_model.find('body', 'head')
+
+ @composer.cached_property
+ def head_site(self):
+ """Return the head site."""
+ return self.mjcf_model.find('site', 'head')
+
+ @composer.cached_property
+ def observable_joints(self):
+ return self.mjcf_model.find_all('joint')
+
+ @composer.cached_property
+ def actuators(self):
+ return self.mjcf_model.find_all('actuator')
+
+ @composer.cached_property
+ def mocap_tracking_bodies(self):
+ # Which bodies to track?
+ body_names = (
+ 'thorax', 'abdomen', 'head',
+ 'claw_T1_left', 'claw_T1_right',
+ 'claw_T2_left', 'claw_T2_right',
+ 'claw_T3_left', 'claw_T3_right')
+ bodies = []
+ for body_name in body_names:
+ body = self.mjcf_model.find('body', body_name)
+ if body:
+ bodies.append(body)
+ return tuple(bodies)
+
+ @composer.cached_property
+ def end_effectors(self):
+ site_names = ('claw_T1_left', 'claw_T1_right',
+ 'claw_T2_left', 'claw_T2_right',
+ 'claw_T3_left', 'claw_T3_right')
+ sites = []
+ for site_name in site_names:
+ site = self.mjcf_model.find('site', site_name)
+ if site:
+ sites.append(site)
+ return tuple(sites)
+
+ @composer.cached_property
+ def appendages(self):
+ # wings? mouth? antennae?
+ additional_site_names = ('head',)
+ sites = list(self.end_effectors)
+ for site_name in additional_site_names:
+ sites.append(self.mjcf_model.find('site', site_name))
+ return tuple(sites)
+
+ def _build_observables(self):
+ return FruitFlyObservables(self, self._buffer_size, self._eye_camera_size)
+
+ @composer.cached_property
+ def left_eye(self):
+ """Return the left_eye camera."""
+ return self._mjcf_root.find('camera', 'eye_left')
+
+ @composer.cached_property
+ def right_eye(self):
+ """Return the right_eye camera."""
+ return self._mjcf_root.find('camera', 'eye_right')
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ """Required by legacy_base."""
+ return self._mjcf_root.find('camera', 'eye_right')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ """Return ground contact geoms."""
+ return (self._mjcf_root.find('geom', 'tarsal_claw_T1_left_collision'),
+ self._mjcf_root.find('geom', 'tarsal_claw_T1_right_collision'),
+ self._mjcf_root.find('geom', 'tarsal_claw_T2_left_collision'),
+ self._mjcf_root.find('geom', 'tarsal_claw_T2_right_collision'),
+ self._mjcf_root.find('geom', 'tarsal_claw_T3_left_collision'),
+ self._mjcf_root.find('geom', 'tarsal_claw_T3_right_collision'),
+ )
+
+ #-----------------------------------------------------------------------------
+
+ def apply_action(self, physics, action, random_state):
+ """Apply action to walker's actuators."""
+ del random_state
+ if not self.mjcf_model.find_all('actuator'):
+ return
+ # Apply MuJoCo actions.
+ ctrl = np.zeros(physics.model.nu)
+ for key, indices in self._action_indices.items():
+ if self._ctrl_indices[key] and indices:
+ ctrl[self._ctrl_indices[key]] = action[indices]
+ physics.set_control(ctrl)
+
+ #-----------------------------------------------------------------------------
+
+ def get_action_spec(self, physics):
+ """Returns a `BoundedArray` spec matching this walker's actuators."""
+ minimum = []
+ maximum = []
+
+ # MuJoCo actions.
+ indices = []
+ for key, _ in self._action_indices.items():
+ if self._ctrl_indices[key] and self._num_actions[key]:
+ indices.extend(self._ctrl_indices[key])
+ mj_minima, mj_maxima = physics.model.actuator_ctrlrange[indices].T
+ names = [physics.model.id2name(i, 'actuator') or str(i)
+ for i in indices]
+ names = [s.split('/')[-1] for s in names]
+ num_actions = len(indices)
+ minimum.extend(mj_minima)
+ maximum.extend(mj_maxima)
+
+ # User actions.
+ if self._num_actions['user']:
+ minimum.extend(self._num_actions['user'] * [-1.0])
+ maximum.extend(self._num_actions['user'] * [1.0])
+ names.extend([f'user_{i}' for i in range(self._num_actions['user'])])
+ num_actions += self._num_actions['user']
+
+ return specs.BoundedArray(shape=(num_actions,),
+ dtype=float,
+ minimum=np.asarray(minimum),
+ maximum=np.asarray(maximum),
+ name='\t'.join(names))
+
+#-------------------------------------------------------------------------------
+
+
+class FruitFlyObservables(legacy_base.WalkerObservables):
+ """Observables for the fruit fly."""
+
+ def __init__(self, walker, buffer_size, eye_camera_size):
+ self._walker = walker
+ self._buffer_size = buffer_size
+ self._eye_camera_size = eye_camera_size
+ super().__init__(walker)
+
+ @composer.observable
+ def thorax_height(self):
+ """Observe the thorax height."""
+ return observable.MJCFFeature('xpos', self._entity.thorax)[2]
+
+ @composer.observable
+ def abdomen_height(self):
+ """Observe the abdomen height."""
+ return observable.MJCFFeature('xpos', self._entity.abdomen)[2]
+
+ @composer.observable
+ def world_zaxis_hover(self):
+ """The world's z-vector in the frame of the `hover_up_dir` site."""
+ hover_up_dir = self._walker.mjcf_model.find('site', 'hover_up_dir')
+ return observable.MJCFFeature('xmat', hover_up_dir)[6:]
+
+ @composer.observable
+ def world_zaxis(self):
+ """The world's z-vector in this Walker's torso frame."""
+ return observable.MJCFFeature('xmat', self._entity.root_body)[6:]
+
+ @composer.observable
+ def world_zaxis_abdomen(self):
+ """The world's z-vector in this Walker's abdomen frame."""
+ return observable.MJCFFeature('xmat', self._entity.abdomen)[6:]
+
+ @composer.observable
+ def world_zaxis_head(self):
+ """The world's z-vector in this Walker's head frame."""
+ return observable.MJCFFeature('xmat', self._entity.head)[6:]
+
+ @composer.observable
+ def force(self):
+ """Force sensors."""
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.force,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @composer.observable
+ def touch(self):
+ """Touch sensors."""
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.touch,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @composer.observable
+ def accelerometer(self):
+ """Accelerometer readings."""
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.accelerometer,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @composer.observable
+ def gyro(self):
+ """Gyro readings."""
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.gyro,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @composer.observable
+ def velocimeter(self):
+ """Velocimeter readings."""
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.velocimeter,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @composer.observable
+ def actuator_activation(self):
+ """Observe the actuator activation."""
+ model = self._entity.mjcf_model
+ return observable.MJCFFeature('act', model.find_all('actuator'))
+
+ @composer.observable
+ def appendages_pos(self):
+ """Equivalent to `end_effectors_pos` but may include other appendages."""
+
+ def relative_pos_in_egocentric_frame(physics):
+ appendages = physics.bind(self._entity.appendages).xpos
+ torso_pos = physics.bind(self._entity.root_body).xpos
+ torso_mat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(appendages - torso_pos, torso_mat), -1)
+
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @composer.observable
+ def self_contact(self):
+ """Returns the sum of self-contact forces."""
+ def sum_body_contact_forces(physics):
+ walker_id = physics.model.name2id('walker/', 'body')
+ force = np.array((0.0))
+ for contact_id, contact in enumerate(physics.data.contact):
+ # Both geoms must be descendants of the thorax.
+ body1 = physics.model.geom_bodyid[contact.geom1]
+ body2 = physics.model.geom_bodyid[contact.geom2]
+ root1 = physics.model.body_rootid[body1]
+ root2 = physics.model.body_rootid[body2]
+ if not(root1 == walker_id and root2 == walker_id):
+ continue
+ contact_force, _ = physics.data.contact_force(contact_id)
+ force += np.linalg.norm(contact_force)
+ return force
+ return observable.Generic(sum_body_contact_forces,
+ buffer_size=self._buffer_size,
+ aggregator='mean')
+
+ @property
+ def vestibular(self):
+ """Return vestibular information."""
+ return [self.gyro, self.accelerometer,
+ self.velocimeter, self.world_zaxis]
+
+ @property
+ def proprioception(self):
+ """Return proprioceptive information."""
+ return [self.joints_pos, self.joints_vel,
+ self.actuator_activation]
+
+ @property
+ def orientation(self):
+ """Return orientation of world z-axis in local frame."""
+ return [self.world_zaxis, self.world_zaxis_abdomen, self.world_zaxis_head]
+
+ @composer.observable
+ def right_eye(self):
+ """Observable of the right_eye camera."""
+
+ if not hasattr(self, '_scene_options'):
+ # Render this walker's geoms.
+ self._scene_options = mj_wrapper.MjvOption()
+ cosmetic_geom_group = 1
+ self._scene_options.geomgroup[cosmetic_geom_group] = 1
+
+ return observable.MJCFCamera(self._entity.right_eye,
+ width=self._eye_camera_size,
+ height=self._eye_camera_size,
+ scene_option=self._scene_options)
+
+ @composer.observable
+ def left_eye(self):
+ """Observable of the left_eye camera."""
+
+ if not hasattr(self, '_scene_options'):
+ # Render this walker's geoms.
+ self._scene_options = mj_wrapper.MjvOption()
+ cosmetic_geom_group = 1
+ self._scene_options.geomgroup[cosmetic_geom_group] = 1
+
+ return observable.MJCFCamera(self._entity.left_eye,
+ width=self._eye_camera_size,
+ height=self._eye_camera_size,
+ scene_option=self._scene_options)
diff --git a/dm_control/locomotion/walkers/initializers/__init__.py b/dm_control/locomotion/walkers/initializers/__init__.py
new file mode 100644
index 00000000..90ebe470
--- /dev/null
+++ b/dm_control/locomotion/walkers/initializers/__init__.py
@@ -0,0 +1,65 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Initializers for the locomotion walkers."""
+
+import abc
+import numpy as np
+
+
+class WalkerInitializer(metaclass=abc.ABCMeta):
+ """The abstract base class for a walker initializer."""
+
+ @abc.abstractmethod
+ def initialize_pose(self, physics, walker, random_state):
+ raise NotImplementedError
+
+
+class UprightInitializer(WalkerInitializer):
+ """An initializer that uses the walker-declared upright pose."""
+
+ def initialize_pose(self, physics, walker, random_state):
+ all_joints_binding = physics.bind(walker.mjcf_model.find_all('joint'))
+ qpos, xpos, xquat = walker.upright_pose
+ if qpos is None:
+ walker.configure_joints(physics, all_joints_binding.qpos0)
+ else:
+ walker.configure_joints(physics, qpos)
+ walker.set_pose(physics, position=xpos, quaternion=xquat)
+ walker.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+
+ class RandomlySampledInitializer(WalkerInitializer):
+ """An initializer that randomly selects between many initializers."""
+
+ def __init__(self, initializers):
+ self._initializers = initializers
+ self.num_initializers = len(initializers)
+
+ def initialize_pose(self, physics, walker, random_state):
+ random_initializer_idx = random_state.randint(0, self.num_initializers)
+ self._initializers[random_initializer_idx].initialize_pose(
+ physics, walker, random_state)
+
+
+class NoOpInitializer(WalkerInitializer):
+ """An initializer that does nothing."""
+
+ def initialize_pose(self, physics, walker, random_state):
+ pass
+
+
+
diff --git a/dm_control/locomotion/walkers/initializers/mocap.py b/dm_control/locomotion/walkers/initializers/mocap.py
new file mode 100644
index 00000000..adbf83ce
--- /dev/null
+++ b/dm_control/locomotion/walkers/initializers/mocap.py
@@ -0,0 +1,48 @@
+# Copyright 2021 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Initializers for walkers that use motion capture data."""
+
+from dm_control.locomotion.mocap import cmu_mocap_data
+from dm_control.locomotion.mocap import loader
+from dm_control.locomotion.walkers import initializers
+
+
+class CMUMocapInitializer(initializers.UprightInitializer):
+ """Initializer that uses data from a CMU mocap dataset.
+
+ Only suitable if walker matches the motion capture data.
+ """
+
+ def __init__(self, mocap_key='CMU_077_02', version='2019'):
+ """Load the trajectory."""
+ ref_path = cmu_mocap_data.get_path_for_cmu(version)
+ self._loader = loader.HDF5TrajectoryLoader(ref_path)
+ self._trajectory = self._loader.get_trajectory(mocap_key)
+
+ def initialize_pose(self, physics, walker, random_state):
+ super(CMUMocapInitializer, self).initialize_pose(
+ physics, walker, random_state)
+ random_time = (self._trajectory.start_time +
+ self._trajectory.dt * random_state.randint(
+ self._trajectory.num_steps))
+ (walker_timestep,) = self._trajectory.get_timestep_data(
+ random_time).walkers
+ physics.bind(walker.mocap_joints).qpos = walker_timestep.joints
+ physics.bind(walker.mocap_joints).qvel = (
+ walker_timestep.joints_velocity)
+ walker.set_velocity(physics,
+ velocity=walker_timestep.velocity,
+ angular_velocity=walker_timestep.angular_velocity)
diff --git a/dm_control/locomotion/walkers/jumping_ball.py b/dm_control/locomotion/walkers/jumping_ball.py
new file mode 100644
index 00000000..058ee472
--- /dev/null
+++ b/dm_control/locomotion/walkers/jumping_ball.py
@@ -0,0 +1,157 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Walkers based on an actuated jumping ball."""
+
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.locomotion.walkers import legacy_base
+import numpy as np
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__),
+ 'assets/jumping_ball')
+
+
+class JumpingBallWithHead(legacy_base.Walker):
+ """A rollable and jumpable ball with a head."""
+
+ def _build(self, name='walker', marker_rgba=None, camera_control=False,
+ initializer=None, add_ears=False, camera_height=None):
+ """Build a JumpingBallWithHead.
+
+ Args:
+ name: name of the walker.
+ marker_rgba: RGBA value set to walker.marker_geoms to distinguish between
+ walkers (in multi-agent setting).
+ camera_control: If `True`, the walker exposes two additional actuated
+ degrees of freedom to control the egocentric camera height and tilt.
+ initializer: (Optional) A `WalkerInitializer` object.
+ add_ears: If `True`, adds a red and a blue sphere ("ears") to the left
+ and right of the head. Better for egocentric vision.
+ camera_height: A float specifying the height of the camera, or `None` if
+ the camera height should be left as specified in the XML model.
+ """
+ super()._build(initializer=initializer)
+ self._mjcf_root = mjcf.from_path(self._xml_path)
+
+ if name:
+ self._mjcf_root.model = name
+
+ if camera_height is not None:
+ self._mjcf_root.find('body', 'egocentric_camera').pos[2] = camera_height
+
+ if add_ears:
+ # Large ears
+ head = self._mjcf_root.find('body', 'head_body')
+ head.add('site', type='sphere', size=(.26,),
+ pos=(.22, 0, 0),
+ rgba=(.7, 0, 0, 1))
+ head.add('site', type='sphere', size=(.26,),
+ pos=(-.22, 0, 0),
+ rgba=(0, 0, .7, 1))
+ # Set corresponding marker color if specified.
+ if marker_rgba is not None:
+ for geom in self.marker_geoms:
+ geom.set_attributes(rgba=marker_rgba)
+
+ self._root_joints = None
+ self._camera_control = camera_control
+ if not camera_control:
+ for name in ('camera_height', 'camera_tilt'):
+ self._mjcf_root.find('actuator', name).remove()
+ self._mjcf_root.find('joint', name).remove()
+
+ @property
+ def _xml_path(self):
+ return os.path.join(_ASSETS_PATH, 'jumping_ball_with_head.xml')
+
+ @property
+ def marker_geoms(self):
+ return [self._mjcf_root.find('geom', 'head')]
+
+ def create_root_joints(self, attachment_frame):
+ root_class = self._mjcf_root.find('default', 'root')
+ root_x = attachment_frame.add(
+ 'joint', name='root_x', type='slide', axis=[1, 0, 0], dclass=root_class)
+ root_y = attachment_frame.add(
+ 'joint', name='root_y', type='slide', axis=[0, 1, 0], dclass=root_class)
+ root_z = attachment_frame.add(
+ 'joint', name='root_z', type='slide', axis=[0, 0, 1], dclass=root_class)
+ self._root_joints = [root_x, root_y, root_z]
+
+ def set_pose(self, physics, position=None, quaternion=None):
+ if position is not None:
+ if self._root_joints is not None:
+ physics.bind(self._root_joints).qpos = position
+ else:
+ super().set_pose(physics, position, quaternion=None)
+ physics.bind(self._mjcf_root.find_all('joint')).qpos = 0.
+ if quaternion is not None:
+ # This walker can only rotate about the z-axis, so we extract only that
+ # component from the quaternion.
+ z_angle = np.arctan2(
+ 2 * (quaternion[0] * quaternion[3] + quaternion[1] * quaternion[2]),
+ 1 - 2 * (quaternion[2] ** 2 + quaternion[3] ** 2))
+ physics.bind(self._mjcf_root.find('joint', 'steer')).qpos = z_angle
+
+ def initialize_episode(self, physics, unused_random_state):
+ # gravity compensation
+ if self._camera_control:
+ gravity = np.hstack([physics.model.opt.gravity, [0, 0, 0]])
+ comp_bodies = physics.bind(self._mjcf_root.find('body',
+ 'egocentric_camera'))
+ comp_bodies.xfrc_applied = -gravity * comp_bodies.mass[..., None]
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ return self._mjcf_root.find_all('actuator')
+
+ @composer.cached_property
+ def root_body(self):
+ return self._mjcf_root.find('body', 'head_body')
+
+ @composer.cached_property
+ def end_effectors(self):
+ return [self._mjcf_root.find('body', 'head_body')]
+
+ @composer.cached_property
+ def observable_joints(self):
+ return [self._mjcf_root.find('joint', 'kick')]
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ return (self._mjcf_root.find('geom', 'shell'),)
+
+
+class RollingBallWithHead(JumpingBallWithHead):
+ """A rollable ball with a head."""
+
+ def _build(self, **kwargs):
+ super()._build(**kwargs)
+ self._mjcf_root.find('actuator', 'kick').remove()
+ self._mjcf_root.find('joint', 'kick').remove()
+
+ @composer.cached_property
+ def observable_joints(self):
+ return []
diff --git a/dm_control/locomotion/walkers/jumping_ball_test.py b/dm_control/locomotion/walkers/jumping_ball_test.py
new file mode 100644
index 00000000..5b19629d
--- /dev/null
+++ b/dm_control/locomotion/walkers/jumping_ball_test.py
@@ -0,0 +1,114 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the Jumping Ball."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.walkers import jumping_ball
+import numpy as np
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.005
+
+
+def _get_jumping_ball_corridor_physics():
+ walker = jumping_ball.JumpingBallWithHead()
+ arena = corr_arenas.EmptyCorridor()
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(5, 0, 0),
+ walker_spawn_rotation=0,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ env = composer.Environment(
+ time_limit=30,
+ task=task,
+ strip_singleton_obs_buffer_dim=True)
+
+ return walker, env
+
+
+class JumpingBallWithHeadTest(parameterized.TestCase):
+
+ def test_can_compile_and_step_simulation(self):
+ _, env = _get_jumping_ball_corridor_physics()
+ physics = env.physics
+ for _ in range(100):
+ physics.step()
+
+ @parameterized.parameters([
+ 'egocentric_camera',
+ ])
+ def test_get_element_property(self, name):
+ attribute_value = getattr(jumping_ball.JumpingBallWithHead(), name)
+ self.assertIsInstance(attribute_value, mjcf.Element)
+
+ @parameterized.parameters([
+ 'actuators',
+ 'end_effectors',
+ 'observable_joints',
+ ])
+ def test_get_element_tuple_property(self, name):
+ attribute_value = getattr(jumping_ball.JumpingBallWithHead(), name)
+ self.assertNotEmpty(attribute_value)
+ for item in attribute_value:
+ self.assertIsInstance(item, mjcf.Element)
+
+ def test_set_name(self):
+ name = 'fred'
+ walker = jumping_ball.JumpingBallWithHead(name=name)
+ self.assertEqual(walker.mjcf_model.model, name)
+
+ @parameterized.parameters(
+ 'sensors_velocimeter',
+ 'world_zaxis',
+ )
+ def test_evaluate_observable(self, name):
+ walker, env = _get_jumping_ball_corridor_physics()
+ physics = env.physics
+ observable = getattr(walker.observables, name)
+ observation = observable(physics)
+ self.assertIsInstance(observation, (float, np.ndarray))
+
+ def test_proprioception(self):
+ walker = jumping_ball.JumpingBallWithHead()
+ for item in walker.observables.proprioception:
+ self.assertIsInstance(item, observable_base.Observable)
+
+ @parameterized.parameters(
+ dict(camera_control=True, add_ears=True, camera_height=1.),
+ dict(camera_control=True, add_ears=False, camera_height=1.),
+ dict(camera_control=False, add_ears=True, camera_height=1.),
+ dict(camera_control=False, add_ears=False, camera_height=1.),
+ dict(camera_control=True, add_ears=True, camera_height=None),
+ dict(camera_control=True, add_ears=False, camera_height=None),
+ dict(camera_control=False, add_ears=True, camera_height=None),
+ dict(camera_control=False, add_ears=False, camera_height=None),
+ )
+ def test_instantiation(self, camera_control, add_ears, camera_height):
+ jumping_ball.JumpingBallWithHead(camera_control=camera_control,
+ add_ears=add_ears,
+ camera_height=camera_height)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walkers/legacy_base.py b/dm_control/locomotion/walkers/legacy_base.py
new file mode 100644
index 00000000..59ff9442
--- /dev/null
+++ b/dm_control/locomotion/walkers/legacy_base.py
@@ -0,0 +1,380 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for Walkers."""
+
+import abc
+
+from dm_control import composer
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import initializers
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+import numpy as np
+
+_RANGEFINDER_SCALE = 10.0
+_TOUCH_THRESHOLD = 1e-3
+
+
+class Walker(base.Walker):
+ """Legacy base class for Walker robots."""
+
+ def _build(self, initializer=None):
+ self._end_effectors_pos_sensors = []
+ self._obs_on_other = {}
+ try:
+ self._initializers = tuple(initializer)
+ except TypeError:
+ self._initializers = (initializer or initializers.UprightInitializer(),)
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose()
+
+ def _build_observables(self):
+ return WalkerObservables(self)
+
+ def reinitialize_pose(self, physics, random_state):
+ for initializer in self._initializers:
+ initializer.initialize_pose(physics, self, random_state)
+
+ def aliveness(self, physics):
+ """A measure of the aliveness of the walker.
+
+ The aliveness measure can be used to decide on termination (e.g. an ant
+ that has flipped over and cannot recover), or as a shaping reward to
+ maintain a desired alive pose (e.g. a humanoid remaining upright).
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a `float` in the range of [-1., 0.] where -1 means not alive and 0. means
+ alive. In walkers for which the concept of aliveness does not make sense,
+ the default implementation is to always return 0.0.
+ """
+ return 0.
+
+ @property
+ @abc.abstractmethod
+ def ground_contact_geoms(self):
+ """Geoms in this walker that are expected to be in contact with the ground.
+
+ This property is used by some tasks to determine contact-based failure
+ termination. It should only contain geoms that are expected to be in
+ contact with the ground during "normal" locomotion. For example, for a
+ humanoid model, this property would be expected to contain only the geoms
+ that make up the two feet.
+
+ Note that certain specialized tasks may also allow geoms that are not listed
+ here to be in contact with the ground. For example, a humanoid cartwheel
+ task would also allow the hands to touch the ground in addition to the feet.
+ """
+ raise NotImplementedError
+
+ def after_compile(self, physics, unused_random_state):
+ super().after_compile(physics, unused_random_state)
+ self._end_effector_geom_ids = set()
+ for eff_body in self.end_effectors:
+ eff_geom = eff_body.find_all('geom')
+ self._end_effector_geom_ids |= set(physics.bind(eff_geom).element_id)
+ self._body_geom_ids = set(
+ physics.bind(geom).element_id
+ for geom in self.mjcf_model.find_all('geom'))
+ self._body_geom_ids.difference_update(self._end_effector_geom_ids)
+
+ @property
+ def end_effector_geom_ids(self):
+ return self._end_effector_geom_ids
+
+ @property
+ def body_geom_ids(self):
+ return self._body_geom_ids
+
+ @property
+ def obs_on_other(self):
+ return self._obs_on_other
+
+ def end_effector_contacts(self, physics):
+ """Collect the contacts with the end effectors.
+
+ This function returns any contacts being made with any of the end effectors,
+ both the other geom with which contact is being made as well as the
+ magnitude.
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a dict mapping each tuple of geom ids, one of which is an end effector,
+ to the total magnitude of all contacts between these geoms.
+ """
+ return self.collect_contacts(physics, self._end_effector_geom_ids)
+
+ def body_contacts(self, physics):
+ """Collect the contacts with the body.
+
+ This function returns any contacts being made with any of the body geoms,
+ except the end effectors: both the other geom with which contact is being
+ made and the magnitude.
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a dict mapping each tuple of geom ids, one of which is a body geom,
+ to the total magnitude of all contacts between these geoms.
+ """
+ return self.collect_contacts(physics, self._body_geom_ids)
+
+ def collect_contacts(self, physics, geom_ids):
+ contacts = {}
+ forcetorque = np.zeros(6)
+ for i, contact in enumerate(physics.data.contact):
+ if ((contact.geom1 in geom_ids) or
+ (contact.geom2 in geom_ids)) and contact.dist < contact.includemargin:
+ mjlib.mj_contactForce(physics.model.ptr, physics.data.ptr, i,
+ forcetorque)
+ contacts[(contact.geom1, contact.geom2)] = (forcetorque[0]
+ + contacts.get(
+ (contact.geom1,
+ contact.geom2), 0.))
+ return contacts
+
+ @property
+ @abc.abstractmethod
+ def end_effectors(self):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def egocentric_camera(self):
+ raise NotImplementedError
+
+ @composer.cached_property
+ def touch_sensors(self):
+ return self._mjcf_root.sensor.get_children('touch')
+
+ @property
+ def prev_action(self):
+ """Returns the actuation actions applied in the previous step.
+
+ Concrete walker implementations should provide their own caching mechanism
+ in order to access this observable (for example, through `apply_action`).
+ """
+ raise NotImplementedError
+
+ def after_substep(self, physics, random_state):
+ del random_state # Unused.
+ # As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped
+ # unless these quantities are needed by the simulation. We need these in
+ # order to calculate `torso_{x,y}vel`, so we therefore call `mj_subtreeVel`
+ # explicitly.
+ # TODO(b/123065920): Consider using a `subtreelinvel` sensor instead.
+ mjlib.mj_subtreeVel(physics.model.ptr, physics.data.ptr)
+
+ @composer.cached_property
+ def mocap_joints(self):
+ return tuple(self.mjcf_model.find_all('joint'))
+
+ def actuator_force(self, physics):
+ return physics.bind(self.observable_joints).qfrc_actuator
+
+ @composer.cached_property
+ def mocap_to_observable_joint_order(self):
+ mocap_to_obs = [self.mocap_joints.index(j) for j in self.observable_joints]
+ return mocap_to_obs
+
+ @composer.cached_property
+ def observable_to_mocap_joint_order(self):
+ obs_to_mocap = [self.observable_joints.index(j) for j in self.mocap_joints]
+ return obs_to_mocap
+
+ @property
+ def end_effectors_pos_sensors(self):
+ return self._end_effectors_pos_sensors
+
+
+class WalkerObservables(base.WalkerObservables):
+ """Legacy base class for Walker observables."""
+
+ @composer.observable
+ def joints_vel(self):
+ return observable.MJCFFeature('qvel', self._entity.observable_joints)
+
+ @composer.observable
+ def body_height(self):
+ return observable.MJCFFeature('xpos', self._entity.root_body)[2]
+
+ @composer.observable
+ def end_effectors_pos(self):
+ """Position of end effectors relative to torso, in the egocentric frame."""
+ self._entity.end_effectors_pos_sensors[:] = []
+ for effector in self._entity.end_effectors:
+ objtype = effector.tag
+ if objtype == 'body':
+ objtype = 'xbody'
+ self._entity.end_effectors_pos_sensors.append(
+ self._entity.mjcf_model.sensor.add(
+ 'framepos', name=effector.name + '_end_effector',
+ objtype=objtype, objname=effector,
+ reftype='xbody', refname=self._entity.root_body))
+ def relative_pos_in_egocentric_frame(physics):
+ return np.reshape(
+ physics.bind(self._entity.end_effectors_pos_sensors).sensordata, -1)
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @composer.observable
+ def world_zaxis(self):
+ """The world's z-vector in this Walker's torso frame."""
+ return observable.MJCFFeature('xmat', self._entity.root_body)[6:]
+
+ @composer.observable
+ def sensors_velocimeter(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.velocimeter)
+
+ @composer.observable
+ def sensors_force(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.force)
+
+ @composer.observable
+ def sensors_torque(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.torque)
+
+ @composer.observable
+ def sensors_touch(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.sensor.touch,
+ corruptor=
+ lambda v, random_state: np.array(v > _TOUCH_THRESHOLD, dtype=float))
+
+ @composer.observable
+ def sensors_rangefinder(self):
+ def tanh_rangefinder(physics):
+ raw = physics.bind(self._entity.mjcf_model.sensor.rangefinder).sensordata
+ raw = np.array(raw)
+ raw[raw == -1.0] = np.inf
+ return _RANGEFINDER_SCALE * np.tanh(raw / _RANGEFINDER_SCALE)
+ return observable.Generic(tanh_rangefinder)
+
+ @composer.observable
+ def egocentric_camera(self):
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=64, height=64)
+
+ @composer.observable
+ def position(self):
+ return observable.MJCFFeature('xpos', self._entity.root_body)
+
+ @composer.observable
+ def orientation(self):
+ return observable.MJCFFeature('xmat', self._entity.root_body)
+
+ def add_egocentric_vector(self,
+ name,
+ world_frame_observable,
+ enabled=True,
+ origin_callable=None,
+ **kwargs):
+
+ def _egocentric(physics, origin_callable=origin_callable):
+ vec = world_frame_observable.observation_callable(physics)()
+ origin_callable = origin_callable or (lambda physics: np.zeros(vec.size))
+ delta = vec - origin_callable(physics)
+ return self._entity.transform_vec_to_egocentric_frame(physics, delta)
+
+ self._observables[name] = observable.Generic(_egocentric, **kwargs)
+ self._observables[name].enabled = enabled
+
+ def add_egocentric_xmat(self, name, xmat_observable, enabled=True, **kwargs):
+
+ def _egocentric(physics):
+ return self._entity.transform_xmat_to_egocentric_frame(
+ physics,
+ xmat_observable.observation_callable(physics)())
+
+ self._observables[name] = observable.Generic(_egocentric, **kwargs)
+ self._observables[name].enabled = enabled
+
+ # Semantic groupings of Walker observables.
+ def _collect_from_attachments(self, attribute_name):
+ out = []
+ for entity in self._entity.iter_entities(exclude_self=True):
+ out.extend(getattr(entity.observables, attribute_name, []))
+ return out
+
+ @property
+ def proprioception(self):
+ return ([self.joints_pos, self.joints_vel,
+ self.body_height, self.end_effectors_pos, self.world_zaxis] +
+ self._collect_from_attachments('proprioception'))
+
+ @property
+ def kinematic_sensors(self):
+ return ([self.sensors_gyro, self.sensors_velocimeter,
+ self.sensors_accelerometer] +
+ self._collect_from_attachments('kinematic_sensors'))
+
+ @property
+ def dynamic_sensors(self):
+ return ([self.sensors_force, self.sensors_torque, self.sensors_touch] +
+ self._collect_from_attachments('dynamic_sensors'))
+
+ # Convenience observables for defining rewards and terminations.
+ @composer.observable
+ def veloc_strafe(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[1]
+
+ @composer.observable
+ def veloc_up(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[2]
+
+ @composer.observable
+ def veloc_forward(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[0]
+
+ @composer.observable
+ def gyro_backward_roll(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[0]
+
+ @composer.observable
+ def gyro_rightward_roll(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[1]
+
+ @composer.observable
+ def gyro_anticlockwise_spin(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[2]
+
+ @composer.observable
+ def torso_xvel(self):
+ return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[0]
+
+ @composer.observable
+ def torso_yvel(self):
+ return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[1]
+
+ @composer.observable
+ def prev_action(self):
+ return observable.Generic(lambda _: self._entity.prev_action)
diff --git a/dm_control/locomotion/walkers/rescale.py b/dm_control/locomotion/walkers/rescale.py
new file mode 100644
index 00000000..f1a5226e
--- /dev/null
+++ b/dm_control/locomotion/walkers/rescale.py
@@ -0,0 +1,60 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Function to rescale the walkers."""
+
+
+from dm_control import mjcf
+
+
+def rescale_subtree(body, position_factor, size_factor):
+ """Recursively rescales an entire subtree of an MJCF model."""
+ for child in body.all_children():
+ if child.tag == 'sensor':
+ continue
+ if getattr(child, 'fromto', None) is not None:
+ new_pos = position_factor * 0.5 * (child.fromto[3:] + child.fromto[:3])
+ new_size = size_factor * 0.5 * (child.fromto[3:] - child.fromto[:3])
+ child.fromto[:3] = new_pos - new_size
+ child.fromto[3:] = new_pos + new_size
+ if getattr(child, 'pos', None) is not None:
+ child.pos *= position_factor
+ if getattr(child, 'size', None) is not None:
+ child.size *= size_factor
+ if child.tag == 'body' or child.tag == 'worldbody':
+ rescale_subtree(child, position_factor, size_factor)
+
+
+def rescale_humanoid(walker, position_factor, size_factor=None, mass=None):
+ """Rescales a humanoid walker's lengths, sizes, and masses."""
+ body = walker.mjcf_model.find('body', 'root')
+ subtree_root = body.parent
+ if size_factor is None:
+ size_factor = position_factor
+ rescale_subtree(subtree_root, position_factor, size_factor)
+
+ if mass is not None:
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model.root_model)
+ current_mass = physics.bind(walker.root_body).subtreemass
+ mass_factor = mass / current_mass
+ for body in walker.root_body.find_all('body'):
+ inertial = getattr(body, 'inertial', None)
+ if inertial:
+ inertial.mass *= mass_factor
+ for geom in walker.root_body.find_all('geom'):
+ if geom.mass is not None:
+ geom.mass *= mass_factor
+ else:
+ current_density = geom.density if geom.density is not None else 1000
+ geom.density = current_density * mass_factor
diff --git a/dm_control/locomotion/walkers/rescale_test.py b/dm_control/locomotion/walkers/rescale_test.py
new file mode 100644
index 00000000..f125dc34
--- /dev/null
+++ b/dm_control/locomotion/walkers/rescale_test.py
@@ -0,0 +1,61 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for rescaling bodies."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.walkers import rescale
+import numpy as np
+
+
+class RescaleTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ # build a simple three-link chain with an endpoint site
+ self._mjcf_model = mjcf.RootElement()
+ body = self._mjcf_model.worldbody.add('body', pos=[0, 0, 0])
+ body.add('geom', type='capsule', fromto=[0, 0, 0, 0, 0, -0.4], size=[0.06])
+ body.add('joint', type='ball')
+ body = body.add('body', pos=[0, 0, -0.5])
+ body.add('geom', type='capsule', pos=[0, 0, -0.15], size=[0.06, 0.15])
+ body.add('joint', type='ball')
+ body = body.add('body', pos=[0, 0, -0.4])
+ body.add('geom', type='capsule', fromto=[0, 0, 0, 0.3, 0, -0.4],
+ size=[0.06])
+ body.add('joint', type='ball')
+ body.add('site', name='endpoint', type='sphere', pos=[0.3, 0, -0.4],
+ size=[0.1])
+
+ def test_rescale(self):
+ # verify endpoint is where expected
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+ np.testing.assert_allclose(physics.named.data.site_xpos['endpoint'],
+ np.array([0.3, 0., -1.3]), atol=1e-15)
+
+ # rescale chain and verify endpoint is where expected after modification
+ subtree_root = self._mjcf_model
+ position_factor = .5
+ size_factor = .5
+ rescale.rescale_subtree(subtree_root, position_factor, size_factor)
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+ np.testing.assert_allclose(physics.named.data.site_xpos['endpoint'],
+ np.array([0.15, 0., -0.65]), atol=1e-15)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walkers/rodent.py b/dm_control/locomotion/walkers/rodent.py
new file mode 100644
index 00000000..b9d29b10
--- /dev/null
+++ b/dm_control/locomotion/walkers/rodent.py
@@ -0,0 +1,334 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A Rodent walker."""
+
+import os
+import re
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+from dm_control.mujoco import wrapper as mj_wrapper
+import numpy as np
+
+_XML_PATH = os.path.join(os.path.dirname(__file__),
+ 'assets/rodent.xml')
+
+_RAT_MOCAP_JOINTS = [
+ 'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist',
+ 'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist',
+ 'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend', 'knee_L', 'ankle_L',
+ 'toe_L', 'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend', 'knee_R',
+ 'ankle_R', 'toe_R', 'vertebra_C1_extend', 'vertebra_C1_bend',
+ 'vertebra_C2_extend', 'vertebra_C2_bend', 'vertebra_C3_extend',
+ 'vertebra_C3_bend', 'vertebra_C4_extend', 'vertebra_C4_bend',
+ 'vertebra_C5_extend', 'vertebra_C5_bend', 'vertebra_C6_extend',
+ 'vertebra_C6_bend', 'vertebra_C7_extend', 'vertebra_C9_bend',
+ 'vertebra_C11_extend', 'vertebra_C13_bend', 'vertebra_C15_extend',
+ 'vertebra_C17_bend', 'vertebra_C19_extend', 'vertebra_C21_bend',
+ 'vertebra_C23_extend', 'vertebra_C25_bend', 'vertebra_C27_extend',
+ 'vertebra_C29_bend', 'vertebra_cervical_5_extend',
+ 'vertebra_cervical_4_bend', 'vertebra_cervical_3_twist',
+ 'vertebra_cervical_2_extend', 'vertebra_cervical_1_bend',
+ 'vertebra_axis_twist', 'vertebra_atlant_extend', 'atlas', 'mandible',
+ 'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend', 'shoulder_L',
+ 'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L', 'scapula_R_supinate',
+ 'scapula_R_abduct', 'scapula_R_extend', 'shoulder_R', 'shoulder_sup_R',
+ 'elbow_R', 'wrist_R', 'finger_R'
+]
+
+
+_UPRIGHT_POS = (0.0, 0.0, 0.0)
+_UPRIGHT_QUAT = (1., 0., 0., 0.)
+_TORQUE_THRESHOLD = 60
+
+
+class Rat(legacy_base.Walker):
+ """A position-controlled rat with control range scaled to [-1, 1]."""
+
+ def _build(self,
+ params=None,
+ name='walker',
+ torque_actuators=False,
+ foot_mods=False,
+ initializer=None):
+ self.params = params
+ self._mjcf_root = mjcf.from_path(_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+
+ self.body_sites = []
+ super()._build(initializer=initializer)
+
+ # modify actuators
+ if torque_actuators:
+ for actuator in self._mjcf_root.find_all('actuator'):
+ actuator.gainprm = [actuator.forcerange[1]]
+ del actuator.biastype
+ del actuator.biasprm
+
+ # modify ankle and toe limits
+ if foot_mods:
+ self._mjcf_root.find('default', 'ankle').joint.range = [-0.1, 2.]
+ self._mjcf_root.find('default', 'toe').joint.range = [-0.7, 0.87]
+
+ @property
+ def upright_pose(self):
+ """Reset pose to upright position."""
+ return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT)
+
+ @property
+ def mjcf_model(self):
+ """Return the model root."""
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ """Return all actuators."""
+ return tuple(self._mjcf_root.find_all('actuator'))
+
+ @composer.cached_property
+ def root_body(self):
+ """Return the root (torso) body."""
+ return self._mjcf_root.find('body', 'torso')
+
+ @composer.cached_property
+ def pelvis_body(self):
+ """Return the pelvis body."""
+ return self._mjcf_root.find('body', 'pelvis')
+
+ @composer.cached_property
+ def head(self):
+ """Return the head."""
+ return self._mjcf_root.find('body', 'skull')
+
+ @composer.cached_property
+ def left_arm_root(self):
+ """Return the left arm."""
+ return self._mjcf_root.find('body', 'scapula_L')
+
+ @composer.cached_property
+ def right_arm_root(self):
+ """Return the right arm."""
+ return self._mjcf_root.find('body', 'scapula_R')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ """Return ground contact geoms."""
+ return tuple(
+ self._mjcf_root.find('body', 'foot_L').find_all('geom') +
+ self._mjcf_root.find('body', 'foot_R').find_all('geom') +
+ self._mjcf_root.find('body', 'hand_L').find_all('geom') +
+ self._mjcf_root.find('body', 'hand_R').find_all('geom') +
+ self._mjcf_root.find('body', 'vertebra_C1').find_all('geom')
+ )
+
+ @composer.cached_property
+ def standing_height(self):
+ """Return standing height."""
+ return self.params['_STAND_HEIGHT']
+
+ @composer.cached_property
+ def end_effectors(self):
+ """Return end effectors."""
+ return (self._mjcf_root.find('body', 'lower_arm_R'),
+ self._mjcf_root.find('body', 'lower_arm_L'),
+ self._mjcf_root.find('body', 'foot_R'),
+ self._mjcf_root.find('body', 'foot_L'))
+
+ @composer.cached_property
+ def observable_joints(self):
+ """Return observable joints."""
+ return tuple(actuator.joint
+ for actuator in self.actuators # This lint is mistaken; pylint: disable=not-an-iterable
+ if actuator.joint is not None)
+
+ @composer.cached_property
+ def observable_tendons(self):
+ return self._mjcf_root.find_all('tendon')
+
+ @composer.cached_property
+ def mocap_joints(self):
+ return tuple(
+ self._mjcf_root.find('joint', name) for name in _RAT_MOCAP_JOINTS)
+
+ @composer.cached_property
+ def mocap_joint_order(self):
+ return tuple([jnt.name for jnt in self.mocap_joints]) # This lint is mistaken; pylint: disable=not-an-iterable
+
+ @composer.cached_property
+ def bodies(self):
+ """Return all bodies."""
+ return tuple(self._mjcf_root.find_all('body'))
+
+ @composer.cached_property
+ def mocap_tracking_bodies(self):
+ """Return bodies for mocap comparison."""
+ return tuple(body for body in self._mjcf_root.find_all('body')
+ if not re.match(r'(vertebra|hand|toe)', body.name))
+
+ @composer.cached_property
+ def primary_joints(self):
+ """Return primary (non-vertebra) joints."""
+ return tuple(jnt for jnt in self._mjcf_root.find_all('joint')
+ if 'vertebra' not in jnt.name)
+
+ @composer.cached_property
+ def vertebra_joints(self):
+ """Return vertebra joints."""
+ return tuple(jnt for jnt in self._mjcf_root.find_all('joint')
+ if 'vertebra' in jnt.name)
+
+ @composer.cached_property
+ def primary_joint_order(self):
+ joint_names = self.mocap_joint_order
+ primary_names = tuple([jnt.name for jnt in self.primary_joints]) # pylint: disable=not-an-iterable
+ primary_order = []
+ for nm in primary_names:
+ primary_order.append(joint_names.index(nm))
+ return primary_order
+
+ @composer.cached_property
+ def vertebra_joint_order(self):
+ joint_names = self.mocap_joint_order
+ vertebra_names = tuple([jnt.name for jnt in self.vertebra_joints]) # pylint: disable=not-an-iterable
+ vertebra_order = []
+ for nm in vertebra_names:
+ vertebra_order.append(joint_names.index(nm))
+ return vertebra_order
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ """Return the egocentric camera."""
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @property
+ def _xml_path(self):
+ """Return the path to the model .xml file."""
+ return self.params['_XML_PATH']
+
+ @composer.cached_property
+ def joint_actuators(self):
+ """Return all joint actuators."""
+ return tuple([act for act in self._mjcf_root.find_all('actuator')
+ if act.joint])
+
+ @composer.cached_property
+ def joint_actuators_range(self):
+ act_joint_range = []
+ for act in self.joint_actuators: # This lint is mistaken; pylint: disable=not-an-iterable
+ associated_joint = self._mjcf_root.find('joint', act.name)
+ act_range = associated_joint.dclass.joint.range
+ act_joint_range.append(act_range)
+ return act_joint_range
+
+ def pose_to_actuation(self, pose):
+ # For joint actuators, the control giving zero torque at pose q_ref is:
+ # u_ref = [2 q_ref - (r_low + r_up)] / (r_up - r_low)
+ r_lower = np.array([ajr[0] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable
+ r_upper = np.array([ajr[1] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable
+ num_tendon_actuators = len(self.actuators) - len(self.joint_actuators)
+ tendon_actions = np.zeros(num_tendon_actuators)
+ return np.hstack([tendon_actions, (2*pose[self.joint_actuator_order]-
+ (r_lower+r_upper))/(r_upper-r_lower)])
+
+ @composer.cached_property
+ def joint_actuator_order(self):
+ joint_names = self.mocap_joint_order
+ joint_actuator_names = tuple([act.name for act in self.joint_actuators]) # This lint is mistaken; pylint: disable=not-an-iterable
+ actuator_order = []
+ for nm in joint_actuator_names:
+ actuator_order.append(joint_names.index(nm))
+ return actuator_order
+
+ def _build_observables(self):
+ return RodentObservables(self)
+
+
+class RodentObservables(legacy_base.WalkerObservables):
+ """Observables for the Rat."""
+
+ @composer.observable
+ def head_height(self):
+ """Observe the head height."""
+ return observable.MJCFFeature('xpos', self._entity.head)[2]
+
+ @composer.observable
+ def sensors_torque(self):
+ """Observe the torque sensors."""
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.sensor.torque,
+ corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD)
+ )
+
+ @composer.observable
+ def tendons_pos(self):
+ return observable.MJCFFeature('length', self._entity.observable_tendons)
+
+ @composer.observable
+ def tendons_vel(self):
+ return observable.MJCFFeature('velocity', self._entity.observable_tendons)
+
+ @composer.observable
+ def actuator_activation(self):
+ """Observe the actuator activation."""
+ model = self._entity.mjcf_model
+ return observable.MJCFFeature('act', model.find_all('actuator'))
+
+ @composer.observable
+ def appendages_pos(self):
+ """Equivalent to `end_effectors_pos` with head's position appended."""
+
+ def relative_pos_in_egocentric_frame(physics):
+ end_effectors_with_head = (
+ self._entity.end_effectors + (self._entity.head,))
+ end_effector = physics.bind(end_effectors_with_head).xpos
+ torso = physics.bind(self._entity.root_body).xpos
+ xmat = \
+ np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(end_effector - torso, xmat), -1)
+
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @property
+ def proprioception(self):
+ """Return proprioceptive information."""
+ return [
+ self.joints_pos, self.joints_vel,
+ self.tendons_pos, self.tendons_vel,
+ self.actuator_activation,
+ self.body_height, self.end_effectors_pos, self.appendages_pos,
+ self.world_zaxis
+ ] + self._collect_from_attachments('proprioception')
+
+ @composer.observable
+ def egocentric_camera(self):
+ """Observable of the egocentric camera."""
+
+ if not hasattr(self, '_scene_options'):
+ # Don't render this walker's geoms.
+ self._scene_options = mj_wrapper.MjvOption()
+ collision_geom_group = 2
+ self._scene_options.geomgroup[collision_geom_group] = 0
+ cosmetic_geom_group = 1
+ self._scene_options.geomgroup[cosmetic_geom_group] = 0
+
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=64, height=64,
+ scene_option=self._scene_options
+ )
diff --git a/dm_control/locomotion/walkers/rodent_test.py b/dm_control/locomotion/walkers/rodent_test.py
new file mode 100644
index 00000000..398ba897
--- /dev/null
+++ b/dm_control/locomotion/walkers/rodent_test.py
@@ -0,0 +1,123 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the Rodent."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.walkers import rodent
+import numpy as np
+
+ _CONTROL_TIMESTEP = 0.02
+_PHYSICS_TIMESTEP = 0.001
+
+
+def _get_rat_corridor_physics():
+ walker = rodent.Rat()
+ arena = corr_arenas.EmptyCorridor()
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(5, 0, 0),
+ walker_spawn_rotation=0,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ env = composer.Environment(
+ time_limit=30,
+ task=task,
+ strip_singleton_obs_buffer_dim=True)
+
+ return walker, env
+
+
+class RatTest(parameterized.TestCase):
+
+ def test_can_compile_and_step_simulation(self):
+ _, env = _get_rat_corridor_physics()
+ physics = env.physics
+ for _ in range(100):
+ physics.step()
+
+ @parameterized.parameters([
+ 'egocentric_camera',
+ 'head',
+ 'left_arm_root',
+ 'right_arm_root',
+ 'root_body',
+ 'pelvis_body',
+ ])
+ def test_get_element_property(self, name):
+ attribute_value = getattr(rodent.Rat(), name)
+ self.assertIsInstance(attribute_value, mjcf.Element)
+
+ @parameterized.parameters([
+ 'actuators',
+ 'bodies',
+ 'mocap_tracking_bodies',
+ 'end_effectors',
+ 'mocap_joints',
+ 'observable_joints',
+ ])
+ def test_get_element_tuple_property(self, name):
+ attribute_value = getattr(rodent.Rat(), name)
+ self.assertNotEmpty(attribute_value)
+ for item in attribute_value:
+ self.assertIsInstance(item, mjcf.Element)
+
+ def test_set_name(self):
+ name = 'fred'
+ walker = rodent.Rat(name=name)
+ self.assertEqual(walker.mjcf_model.model, name)
+
+ @parameterized.parameters(
+ 'tendons_pos',
+ 'tendons_vel',
+ 'actuator_activation',
+ 'appendages_pos',
+ 'head_height',
+ 'sensors_torque',
+ )
+ def test_evaluate_observable(self, name):
+ walker, env = _get_rat_corridor_physics()
+ physics = env.physics
+ observable = getattr(walker.observables, name)
+ observation = observable(physics)
+ self.assertIsInstance(observation, (float, np.ndarray))
+
+ def test_proprioception(self):
+ walker = rodent.Rat()
+ for item in walker.observables.proprioception:
+ self.assertIsInstance(item, observable_base.Observable)
+
+ def test_can_create_two_rats(self):
+ rat1 = rodent.Rat(name='rat1')
+ rat2 = rodent.Rat(name='rat2')
+ arena = corr_arenas.EmptyCorridor()
+ arena.add_free_entity(rat1)
+ arena.add_free_entity(rat2)
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model) # Should not raise an error.
+
+ rat1.mjcf_model.model = 'rat3'
+ rat2.mjcf_model.model = 'rat4'
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model) # Should not raise an error.
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walkers/scaled_actuators.py b/dm_control/locomotion/walkers/scaled_actuators.py
new file mode 100644
index 00000000..79db8da4
--- /dev/null
+++ b/dm_control/locomotion/walkers/scaled_actuators.py
@@ -0,0 +1,128 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Position & velocity actuators whose controls are scaled to a given range."""
+
+
+_DISALLOWED_KWARGS = frozenset(
+ ['biastype', 'gainprm', 'biasprm', 'ctrllimited',
+ 'joint', 'tendon', 'site', 'slidersite', 'cranksite'])
+_ALLOWED_TAGS = frozenset(['joint', 'tendon', 'site'])
+
+_GOT_INVALID_KWARGS = 'Received invalid keyword argument(s): {}'
+_GOT_INVALID_TARGET = '`target` tag type should be one of {}: got {{}}'.format(
+ sorted(_ALLOWED_TAGS))
+
+
+def _check_target_and_kwargs(target, **kwargs):
+ invalid_kwargs = _DISALLOWED_KWARGS.intersection(kwargs)
+ if invalid_kwargs:
+ raise TypeError(_GOT_INVALID_KWARGS.format(sorted(invalid_kwargs)))
+ if target.tag not in _ALLOWED_TAGS:
+ raise TypeError(_GOT_INVALID_TARGET.format(target))
+
+
+def add_position_actuator(target, qposrange, ctrlrange=(-1, 1),
+ kp=1.0, **kwargs):
+ """Adds a scaled position actuator that is bound to the specified element.
+
+ This is equivalent to MuJoCo's built-in `<position>` actuator where an affine
+ transformation is pre-applied to the control signal, such that the minimum
+ control value corresponds to the minimum desired position, and the
+ maximum control value corresponds to the maximum desired position.
+
+ Args:
+ target: A PyMJCF joint, tendon, or site element object that is to be
+ controlled.
+ qposrange: A sequence of two numbers specifying the allowed range of target
+ position.
+ ctrlrange: A sequence of two numbers specifying the allowed range of
+ this actuator's control signal.
+ kp: The gain parameter of this position actuator.
+ **kwargs: Additional MJCF attributes for this actuator element.
+ The following attributes are disallowed: `['biastype', 'gainprm',
+ 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site',
+ 'slidersite', 'cranksite']`.
+
+ Returns:
+ A PyMJCF actuator element that has been added to the MJCF model containing
+ the specified `target`.
+
+ Raises:
+ TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute,
+ or `target` is not an allowed MJCF element type.
+ """
+ _check_target_and_kwargs(target, **kwargs)
+ kwargs[target.tag] = target
+
+ slope = (qposrange[1] - qposrange[0]) / (ctrlrange[1] - ctrlrange[0])
+ g0 = kp * slope
+ b0 = kp * (qposrange[0] - slope * ctrlrange[0])
+ b1 = -kp
+ b2 = 0
+ return target.root.actuator.add('general',
+ biastype='affine',
+ gainprm=[g0],
+ biasprm=[b0, b1, b2],
+ ctrllimited=True,
+ ctrlrange=ctrlrange,
+ **kwargs)
+
+
+def add_velocity_actuator(target, qvelrange, ctrlrange=(-1, 1),
+ kv=1.0, **kwargs):
+ """Adds a scaled velocity actuator that is bound to the specified element.
+
+ This is equivalent to MuJoCo's built-in `<velocity>` actuator where an affine
+ transformation is pre-applied to the control signal, such that the minimum
+ control value corresponds to the minimum desired velocity, and the
+ maximum control value corresponds to the maximum desired velocity.
+
+ Args:
+ target: A PyMJCF joint, tendon, or site element object that is to be
+ controlled.
+ qvelrange: A sequence of two numbers specifying the allowed range of target
+ velocity.
+ ctrlrange: A sequence of two numbers specifying the allowed range of
+ this actuator's control signal.
+ kv: The gain parameter of this velocity actuator.
+ **kwargs: Additional MJCF attributes for this actuator element.
+ The following attributes are disallowed: `['biastype', 'gainprm',
+ 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site',
+ 'slidersite', 'cranksite']`.
+
+ Returns:
+ A PyMJCF actuator element that has been added to the MJCF model containing
+ the specified `target`.
+
+ Raises:
+ TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute,
+ or `target` is not an allowed MJCF element type.
+ """
+ _check_target_and_kwargs(target, **kwargs)
+ kwargs[target.tag] = target
+
+ slope = (qvelrange[1] - qvelrange[0]) / (ctrlrange[1] - ctrlrange[0])
+ g0 = kv * slope
+ b0 = kv * (qvelrange[0] - slope * ctrlrange[0])
+ b1 = 0
+ b2 = -kv
+ return target.root.actuator.add('general',
+ biastype='affine',
+ gainprm=[g0],
+ biasprm=[b0, b1, b2],
+ ctrllimited=True,
+ ctrlrange=ctrlrange,
+ **kwargs)
diff --git a/dm_control/locomotion/walkers/scaled_actuators_test.py b/dm_control/locomotion/walkers/scaled_actuators_test.py
new file mode 100644
index 00000000..0b155a34
--- /dev/null
+++ b/dm_control/locomotion/walkers/scaled_actuators_test.py
@@ -0,0 +1,131 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for scaled actuators."""
+
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.walkers import scaled_actuators
+import numpy as np
+
+
+class ScaledActuatorsTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._mjcf_model = mjcf.RootElement()
+ self._min = -1.4
+ self._max = 2.3
+ self._gain = 1.7
+ self._scaled_min = -0.8
+ self._scaled_max = 1.3
+ self._range = self._max - self._min
+ self._scaled_range = self._scaled_max - self._scaled_min
+ self._joints = []
+ for _ in range(2):
+ body = self._mjcf_model.worldbody.add('body')
+ body.add('geom', type='sphere', size=[1])
+ self._joints.append(body.add('joint', type='hinge'))
+ self._scaled_actuator_joint = self._joints[0]
+ self._standard_actuator_joint = self._joints[1]
+ self._random_state = np.random.RandomState(3474)
+
+ def _set_actuator_controls(self, physics, normalized_ctrl,
+ scaled_actuator=None, standard_actuator=None):
+ if scaled_actuator is not None:
+ physics.bind(scaled_actuator).ctrl = (
+ normalized_ctrl * self._scaled_range + self._scaled_min)
+ if standard_actuator is not None:
+ physics.bind(standard_actuator).ctrl = (
+ normalized_ctrl * self._range + self._min)
+
+ def _assert_same_qfrc_actuator(self, physics, joint1, joint2):
+ np.testing.assert_allclose(physics.bind(joint1).qfrc_actuator,
+ physics.bind(joint2).qfrc_actuator)
+
+ def test_position_actuator(self):
+ scaled_actuator = scaled_actuators.add_position_actuator(
+ target=self._scaled_actuator_joint, kp=self._gain,
+ qposrange=(self._min, self._max),
+ ctrlrange=(self._scaled_min, self._scaled_max))
+ standard_actuator = self._mjcf_model.actuator.add(
+ 'position', joint=self._standard_actuator_joint, kp=self._gain,
+ ctrllimited=True, ctrlrange=(self._min, self._max))
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+
+ # Zero torque.
+ physics.bind(self._scaled_actuator_joint).qpos = (
+ 0.2345 * self._range + self._min)
+ self._set_actuator_controls(physics, 0.2345, scaled_actuator)
+ np.testing.assert_allclose(
+ physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15)
+
+ for _ in range(100):
+ normalized_ctrl = self._random_state.uniform()
+ physics.bind(self._joints).qpos = (
+ self._random_state.uniform() * self._range + self._min)
+ self._set_actuator_controls(physics, normalized_ctrl,
+ scaled_actuator, standard_actuator)
+ self._assert_same_qfrc_actuator(
+ physics, self._scaled_actuator_joint, self._standard_actuator_joint)
+
+ def test_velocity_actuator(self):
+ scaled_actuator = scaled_actuators.add_velocity_actuator(
+ target=self._scaled_actuator_joint, kv=self._gain,
+ qvelrange=(self._min, self._max),
+ ctrlrange=(self._scaled_min, self._scaled_max))
+ standard_actuator = self._mjcf_model.actuator.add(
+ 'velocity', joint=self._standard_actuator_joint, kv=self._gain,
+ ctrllimited=True, ctrlrange=(self._min, self._max))
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+
+ # Zero torque.
+ physics.bind(self._scaled_actuator_joint).qvel = (
+ 0.5432 * self._range + self._min)
+ self._set_actuator_controls(physics, 0.5432, scaled_actuator)
+ np.testing.assert_allclose(
+ physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15)
+
+ for _ in range(100):
+ normalized_ctrl = self._random_state.uniform()
+ physics.bind(self._joints).qvel = (
+ self._random_state.uniform() * self._range + self._min)
+ self._set_actuator_controls(physics, normalized_ctrl,
+ scaled_actuator, standard_actuator)
+ self._assert_same_qfrc_actuator(
+ physics, self._scaled_actuator_joint, self._standard_actuator_joint)
+
+ def test_invalid_kwargs(self):
+ invalid_kwargs = dict(joint=self._scaled_actuator_joint, ctrllimited=False)
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ scaled_actuators._GOT_INVALID_KWARGS.format(sorted(invalid_kwargs))):
+ scaled_actuators.add_position_actuator(
+ target=self._scaled_actuator_joint,
+ qposrange=(self._min, self._max),
+ **invalid_kwargs)
+
+ def test_invalid_target(self):
+ invalid_target = self._mjcf_model.worldbody
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ scaled_actuators._GOT_INVALID_TARGET.format(invalid_target)):
+ scaled_actuators.add_position_actuator(
+ target=invalid_target, qposrange=(self._min, self._max))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/locomotion/walls.png b/dm_control/locomotion/walls.png
new file mode 100644
index 00000000..fc153862
Binary files /dev/null and b/dm_control/locomotion/walls.png differ
diff --git a/dm_control/manipulation/__init__.py b/dm_control/manipulation/__init__.py
new file mode 100644
index 00000000..dcfc9eea
--- /dev/null
+++ b/dm_control/manipulation/__init__.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A structured set of manipulation tasks with a single entry point."""
+
+
+from absl import flags
+from dm_control import composer as _composer
+from dm_control.manipulation import bricks as _bricks
+from dm_control.manipulation import lift as _lift
+from dm_control.manipulation import place as _place
+from dm_control.manipulation import reach as _reach
+from dm_control.manipulation.shared import registry as _registry
+
+_registry.done_importing_tasks()
+
+_TIME_LIMIT = 10.
+_TIMEOUT = None
+
+ALL = tuple(_registry.get_all_names())
+TAGS = tuple(_registry.get_tags())
+
+flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
+FLAGS = flags.FLAGS
+
+
+def _get_timeout():
+ global _TIMEOUT
+ if _TIMEOUT is None:
+ if FLAGS.is_parsed():
+ _TIMEOUT = FLAGS.timeout
+ else:
+ _TIMEOUT = FLAGS['timeout'].default
+ return _TIMEOUT
+
+
+def get_environments_by_tag(tag):
+ """Returns the names of all environments matching a given tag.
+
+ Args:
+ tag: A string from `TAGS`.
+
+ Returns:
+ A tuple of environment names.
+ """
+ return tuple(_registry.get_names_by_tag(tag))
+
+
+def load(environment_name, seed=None):
+ """Loads a manipulation environment.
+
+ Args:
+ environment_name: String, the name of the environment to load. Must be in
+ `ALL`.
+ seed: An optional integer used to seed the task's random number generator.
+ If None (default), the random number generator will self-seed from a
+ platform-dependent source of entropy.
+
+ Returns:
+ An instance of `composer.Environment`.
+ """
+ task = _registry.get_constructor(environment_name)()
+ time_limit = _TIME_LIMIT if _get_timeout() else float('inf')
+ return _composer.Environment(task, time_limit=time_limit, random_state=seed)
diff --git a/dm_control/manipulation/bricks.py b/dm_control/manipulation/bricks.py
new file mode 100644
index 00000000..92fbf4ca
--- /dev/null
+++ b/dm_control/manipulation/bricks.py
@@ -0,0 +1,710 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tasks involving assembly and/or disassembly of bricks."""
+
+import collections
+
+from absl import logging
+from dm_control import composer
+from dm_control.composer import initializers
+from dm_control.composer import variation
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.manipulation.shared import arenas
+from dm_control.manipulation.shared import cameras
+from dm_control.manipulation.shared import constants
+from dm_control.manipulation.shared import observations
+from dm_control.manipulation.shared import registry
+from dm_control.manipulation.shared import robots
+from dm_control.manipulation.shared import tags
+from dm_control.manipulation.shared import workspaces
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import rewards
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+
+_BrickWorkspace = collections.namedtuple(
+ '_BrickWorkspace',
+ ['prop_bbox', 'tcp_bbox', 'goal_hint_pos', 'goal_hint_quat', 'arm_offset'])
+
+# Ensures that the prop does not collide with the table during initialization.
+_PROP_Z_OFFSET = 1e-6
+
+_WORKSPACE = _BrickWorkspace(
+ prop_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PROP_Z_OFFSET),
+ upper=(0.1, 0.1, _PROP_Z_OFFSET)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.15),
+ upper=(0.1, 0.1, 0.4)),
+ goal_hint_pos=(0.2, 0.1, 0.),
+ goal_hint_quat=(-0.38268343, 0., 0., 0.92387953),
+ arm_offset=robots.ARM_OFFSET)
+
+# Alpha value of the visual goal hint representing the goal state for each task.
+_HINT_ALPHA = 0.75
+
+# Distance thresholds for the shaping rewards for getting the top brick close
+# to the bottom brick, and for 'clicking' them together.
+_CLOSE_THRESHOLD = 0.01
+_CLICK_THRESHOLD = 0.001
+
+# Sequence of colors for the brick(s).
+_COLOR_VALUES, _COLOR_NAMES = list(
+ zip(
+ ((1., 0., 0.), 'red'),
+ ((0., 1., 0.), 'green'),
+ ((0., 0., 1.), 'blue'),
+ ((0., 1., 1.), 'cyan'),
+ ((1., 0., 1.), 'magenta'),
+ ((1., 1., 0.), 'yellow'),
+ ))
+
+
+class _Common(composer.Task):
+ """Common components of brick tasks."""
+
+ def __init__(self,
+ arena,
+ arm,
+ hand,
+ num_bricks,
+ obs_settings,
+ workspace,
+ control_timestep):
+ if not 2 <= num_bricks <= 6:
+ raise ValueError('`num_bricks` must be between 2 and 6, got {}.'
+ .format(num_bricks))
+
+ if num_bricks > 3:
+ # The default values computed by MuJoCo's compiler are too small if there
+ # are more than three stacked bricks, since each stacked pair generates
+ # a large number of contacts. The values below are sufficient for up to
+ # 6 stacked bricks.
+ # TODO(b/78331644): It may be useful to log the size of `physics.model`
+ # and `physics.data` after compilation to gauge the
+ # impact of these changes on MuJoCo's memory footprint.
+ arena.mjcf_model.size.nconmax = 400
+ arena.mjcf_model.size.njmax = 1200
+
+ self._arena = arena
+ self._arm = arm
+ self._hand = hand
+ self._arm.attach(self._hand)
+ self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
+ self.control_timestep = control_timestep
+
+ # Add custom camera observable.
+ self._task_observables = cameras.add_camera_observables(
+ arena, obs_settings, cameras.FRONT_CLOSE)
+
+ color_sequence = iter(_COLOR_VALUES)
+ brick_obs_options = observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES)
+
+ bricks = []
+ brick_frames = []
+ goal_hint_bricks = []
+ for _ in range(num_bricks):
+ color = next(color_sequence)
+ brick = props.Duplo(color=color,
+ observable_options=brick_obs_options)
+ brick_frames.append(arena.add_free_entity(brick))
+ bricks.append(brick)
+
+ # Translucent, contactless bricks with no observables. These are used to
+ # provide a visual hint representing the goal state for each task.
+ hint_brick = props.Duplo(color=color)
+ _hintify(hint_brick, alpha=_HINT_ALPHA)
+ arena.attach(hint_brick)
+ goal_hint_bricks.append(hint_brick)
+
+ self._bricks = bricks
+ self._brick_frames = brick_frames
+ self._goal_hint_bricks = goal_hint_bricks
+
+ # Position and quaternion for the goal hint.
+ self._goal_hint_pos = workspace.goal_hint_pos
+ self._goal_hint_quat = workspace.goal_hint_quat
+
+ self._tcp_initializer = initializers.ToolCenterPointInitializer(
+ self._hand, self._arm,
+ position=distributions.Uniform(*workspace.tcp_bbox),
+ quaternion=workspaces.DOWN_QUATERNION)
+
+ # Add sites for visual debugging.
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.tcp_bbox.lower,
+ upper=workspace.tcp_bbox.upper,
+ rgba=constants.GREEN, name='tcp_spawn_area')
+
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.prop_bbox.lower,
+ upper=workspace.prop_bbox.upper,
+ rgba=constants.BLUE, name='prop_spawn_area')
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def arm(self):
+ return self._arm
+
+ @property
+ def hand(self):
+ return self._hand
+
+
+class Stack(_Common):
+ """Build a stack of Duplo bricks."""
+
+ def __init__(self,
+ arena,
+ arm,
+ hand,
+ num_bricks,
+ target_height,
+ moveable_base,
+ randomize_order,
+ obs_settings,
+ workspace,
+ control_timestep):
+ """Initializes a new `Stack` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ num_bricks: The total number of bricks; must be between 2 and 6.
+ target_height: The target number of bricks in the stack in order to get
+ maximum reward. Must be between 2 and `num_bricks`.
+ moveable_base: Boolean specifying whether or not the bottom brick should
+ be moveable.
+ randomize_order: Boolean specifying whether to randomize the desired order
+ of bricks in the stack at the start of each episode.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: A `_BrickWorkspace` instance.
+ control_timestep: Float specifying the control timestep in seconds.
+
+ Raises:
+ ValueError: If `num_bricks` is not between 2 and 6, or if
+ `target_height` is not between 2 and `num_bricks`.
+ """
+ if not 2 <= target_height <= num_bricks:
+ raise ValueError('`target_height` must be between 2 and {}, got {}.'
+ .format(num_bricks, target_height))
+
+ super().__init__(
+ arena=arena,
+ arm=arm,
+ hand=hand,
+ num_bricks=num_bricks,
+ obs_settings=obs_settings,
+ workspace=workspace,
+ control_timestep=control_timestep)
+
+ self._moveable_base = moveable_base
+ self._randomize_order = randomize_order
+ self._target_height = target_height
+ self._prop_bbox = workspace.prop_bbox
+
+ # Shuffled at the start of each episode if `randomize_order` is True.
+ self._desired_order = np.arange(target_height)
+
+ # In the random order case, create a `desired_order` observable (configured
+ # with the `prop_pose` settings) that informs the agent of the desired order.
+ if randomize_order:
+ desired_order_observable = observable.Generic(self._get_desired_order)
+ desired_order_observable.configure(**obs_settings.prop_pose._asdict())
+ self._task_observables['desired_order'] = desired_order_observable
+
+ def _get_desired_order(self, physics):
+ del physics # Unused
+ return self._desired_order.astype(np.double)
+
+ def initialize_episode_mjcf(self, random_state):
+ if self._randomize_order:
+ self._desired_order = random_state.choice(
+ len(self._bricks), size=self._target_height, replace=False)
+ logging.info('Desired stack order (from bottom to top): [%s]',
+ ' '.join(_COLOR_NAMES[i] for i in self._desired_order))
+
+ # If the base of the stack should be fixed, remove the freejoint for the
+ # first brick (and ensure that all the others have freejoints).
+ fixed_indices = [] if self._moveable_base else [self._desired_order[0]]
+ _add_or_remove_freejoints(attachment_frames=self._brick_frames,
+ fixed_indices=fixed_indices)
+
+ # We need to define the prop initializer for the bricks here rather than in
+ # the `__init__`, since `PropPlacer` looks for freejoints on instantiation.
+ self._brick_placer = initializers.PropPlacer(
+ props=self._bricks,
+ position=distributions.Uniform(*self._prop_bbox),
+ quaternion=workspaces.uniform_z_rotation,
+ settle_physics=True)
+
+ def initialize_episode(self, physics, random_state):
+ self._brick_placer(physics, random_state)
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._tcp_initializer(physics, random_state)
+ # Arrange the goal hint bricks in the desired stack order.
+ _build_stack(physics,
+ bricks=self._goal_hint_bricks,
+ base_pos=self._goal_hint_pos,
+ base_quat=self._goal_hint_quat,
+ order=self._desired_order,
+ random_state=random_state)
+
+ def get_reward(self, physics):
+ pairs = list(zip(self._desired_order[:-1], self._desired_order[1:]))
+ pairwise_rewards = _get_pairwise_stacking_rewards(
+ physics=physics, bricks=self._bricks, pairs=pairs)
+ # The final reward is an average over the pairwise rewards.
+ return np.mean(pairwise_rewards)
+
+
+class Reassemble(_Common):
+ """Disassemble a stack of Duplo bricks and reassemble it in another order."""
+
+ def __init__(self,
+ arena,
+ arm,
+ hand,
+ num_bricks,
+ randomize_initial_order,
+ randomize_desired_order,
+ obs_settings,
+ workspace,
+ control_timestep):
+ """Initializes a new `Reassemble` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ num_bricks: The total number of bricks; must be between 2 and 6.
+ randomize_initial_order: Boolean specifying whether to randomize the
+ initial order of bricks in the stack at the start of each episode.
+ randomize_desired_order: Boolean specifying whether to independently
+ randomize the desired order of bricks in the stack at the start of each
+ episode. By default the desired order will be the reverse of the initial
+ order, with the exception of the base brick which is always the same as
+ in the initial order since it is welded in place.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: A `_BrickWorkspace` instance.
+ control_timestep: Float specifying the control timestep in seconds.
+
+ Raises:
+ ValueError: If `num_bricks` is not between 2 and 6.
+ """
+ super().__init__(
+ arena=arena,
+ arm=arm,
+ hand=hand,
+ num_bricks=num_bricks,
+ obs_settings=obs_settings,
+ workspace=workspace,
+ control_timestep=control_timestep)
+ self._randomize_initial_order = randomize_initial_order
+ self._randomize_desired_order = randomize_desired_order
+
+ # Randomized at the start of each episode if `randomize_initial_order` is
+ # True.
+ self._initial_order = np.arange(num_bricks)
+
+ # Randomized at the start of each episode if `randomize_desired_order` is
+ # True.
+ self._desired_order = self._initial_order.copy()
+ self._desired_order[1:] = self._desired_order[-1:0:-1]
+
+ # In the random order case, create a `prop_pose` observable that informs the
+ # agent of the desired order.
+ if randomize_desired_order:
+ desired_order_observable = observable.Generic(self._get_desired_order)
+ desired_order_observable.configure(**obs_settings.prop_pose._asdict())
+ self._task_observables['desired_order'] = desired_order_observable
+
+ # Distributions of positions and orientations for the base of the stack.
+ self._base_pos = distributions.Uniform(*workspace.prop_bbox)
+ self._base_quat = workspaces.uniform_z_rotation
+
+ def _get_desired_order(self, physics):
+ del physics # Unused
+ return self._desired_order.astype(np.double)
+
+ def initialize_episode_mjcf(self, random_state):
+ if self._randomize_initial_order:
+ random_state.shuffle(self._initial_order)
+
+ # The bottom brick will be fixed to the table, so it must be the same in
+ # both the initial and desired order.
+ self._desired_order[0] = self._initial_order[0]
+ # By default the desired order of the other bricks is the opposite of their
+ # initial order.
+ self._desired_order[1:] = self._initial_order[-1:0:-1]
+
+ if self._randomize_desired_order:
+ random_state.shuffle(self._desired_order[1:])
+
+ logging.info('Desired stack order (from bottom to top): [%s]',
+ ' '.join(_COLOR_NAMES[i] for i in self._desired_order))
+
+ # Remove the freejoint from the bottom brick in the stack.
+ _add_or_remove_freejoints(attachment_frames=self._brick_frames,
+ fixed_indices=[self._initial_order[0]])
+
+ def initialize_episode(self, physics, random_state):
+ # Build the initial stack.
+ _build_stack(physics,
+ bricks=self._bricks,
+ base_pos=self._base_pos,
+ base_quat=self._base_quat,
+ order=self._initial_order,
+ random_state=random_state)
+ # Arrange the goal hint bricks into a stack with the desired order.
+ _build_stack(physics,
+ bricks=self._goal_hint_bricks,
+ base_pos=self._goal_hint_pos,
+ base_quat=self._goal_hint_quat,
+ order=self._desired_order,
+ random_state=random_state)
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._tcp_initializer(physics, random_state)
+
+ def get_reward(self, physics):
+ pairs = list(zip(self._desired_order[:-1], self._desired_order[1:]))
+ # We set `close_coef=0.` because the coarse shaping reward causes problems
+ # for this task (it means there is a strong disincentive to break up the
+ # initial stack).
+ pairwise_rewards = _get_pairwise_stacking_rewards(
+ physics=physics,
+ bricks=self._bricks,
+ pairs=pairs,
+ close_coef=0.)
+ # The final reward is an average over the pairwise rewards.
+ return np.mean(pairwise_rewards)
+
+
+def _distance(pos1, pos2):
+ diff = pos1 - pos2
+ return sum(np.sqrt((diff * diff).sum(1)))
+
+
+def _min_stud_to_hole_distance(physics, bottom_brick, top_brick):
+ # Positions of the top left and bottom right studs on the `bottom_brick` and
+ # the top left and bottom right holes on the `top_brick`.
+ stud_pos = physics.bind(bottom_brick.studs[[0, -1], [0, -1]]).xpos
+ hole_pos = physics.bind(top_brick.holes[[0, -1], [0, -1]]).xpos
+ # Bricks are rotationally symmetric, so we compute top left -> top left and
+ # top left -> bottom right distances and return whichever of these is smaller.
+ dist1 = _distance(stud_pos, hole_pos)
+ dist2 = _distance(stud_pos[::-1], hole_pos)
+ return min(dist1, dist2)
+
+
+def _get_pairwise_stacking_rewards(physics, bricks, pairs, close_coef=0.1):
+ """Returns a vector of shaping reward components based on pairwise distances.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ bricks: A list of `composer.Entity` instances corresponding to bricks.
+ pairs: A list of `(bottom_idx, top_idx)` tuples specifying which pairs of
+ bricks should be measured.
+ close_coef: Float specifying the relative weight given to the coarse-
+ tolerance shaping component for getting the bricks close to one another
+ (as opposed to the fine-tolerance component for clicking them together).
+
+ Returns:
+ A numpy array of size `len(pairs)` containing values in (0, 1], where
+ 1 corresponds to a stacked pair of bricks.
+ """
+ distances = []
+ for bottom_idx, top_idx in pairs:
+ bottom_brick = bricks[bottom_idx]
+ top_brick = bricks[top_idx]
+ distances.append(
+ _min_stud_to_hole_distance(physics, bottom_brick, top_brick))
+ distances = np.hstack(distances)
+
+ # Coarse-tolerance component for bringing the holes close to the studs.
+ close = rewards.tolerance(
+ distances, bounds=(0, _CLOSE_THRESHOLD), margin=(_CLOSE_THRESHOLD * 10))
+
+ # Fine-tolerance component for clicking the bricks together.
+ clicked = rewards.tolerance(
+ distances, bounds=(0, _CLICK_THRESHOLD), margin=_CLICK_THRESHOLD)
+
+ # Weighted average of coarse and fine components for each pair of bricks.
+ return np.average([close, clicked], weights=[close_coef, 1.], axis=0)
+
+
+def _build_stack(physics, bricks, base_pos, base_quat, order, random_state):
+ """Builds a stack of bricks.
+
+ Args:
+ physics: Instance of `mjcf.Physics`.
+ bricks: Sequence of `composer.Entity` instances corresponding to bricks.
+ base_pos: Position of the base brick in the stack; either a (3,) numpy array
+ or a `variation.Variation` that yields such arrays.
+ base_quat: Quaternion of the base brick in the stack; either a (4,) numpy
+ array or a `variation.Variation` that yields such arrays.
+ order: Sequence of indices specifying the order in which to stack the
+ bricks.
+ random_state: An `np.random.RandomState` instance.
+ """
+ base_pos = variation.evaluate(base_pos, random_state=random_state)
+ base_quat = variation.evaluate(base_quat, random_state=random_state)
+ bricks[order[0]].set_pose(physics, position=base_pos, quaternion=base_quat)
+ for bottom_idx, top_idx in zip(order[:-1], order[1:]):
+ bottom = bricks[bottom_idx]
+ top = bricks[top_idx]
+ stud_pos = physics.bind(bottom.studs[0, 0]).xpos
+ _, quat = bottom.get_pose(physics)
+ # The reward function treats top left -> top left and top left -> bottom
+ # right configurations as identical, so the orientations of the bricks are
+ # randomized so that 50% of the time the top brick is rotated 180 degrees
+ # relative to the brick below.
+ if random_state.rand() < 0.5:
+ quat = quat.copy()
+ axis = np.array([0., 0., 1.])
+ angle = np.pi
+ mjlib.mju_quatIntegrate(quat, axis, angle)
+ hole_idx = (-1, -1)
+ else:
+ hole_idx = (0, 0)
+ top.set_pose(physics, quaternion=quat)
+
+ # Set the position of the top brick so that its holes line up with the studs
+ # of the brick below.
+ offset = physics.bind(top.holes[hole_idx]).xpos
+ top_pos = stud_pos - offset
+ top.set_pose(physics, position=top_pos)
+
+
+def _add_or_remove_freejoints(attachment_frames, fixed_indices):
+ """Adds or removes freejoints from props.
+
+ Args:
+ attachment_frames: A list of `mjcf.Elements` corresponding to attachment
+ frames.
+ fixed_indices: A list of indices of attachment frames that should be fixed
+ to the world (i.e. have their freejoints removed). Freejoints will be
+ added to all other elements in `attachment_frames` if they do not already
+ possess them.
+ """
+ for i, frame in enumerate(attachment_frames):
+ if i in fixed_indices:
+ if frame.freejoint:
+ frame.freejoint.remove()
+ elif not frame.freejoint:
+ frame.add('freejoint')
+
+
+def _replace_alpha(rgba, alpha=0.3):
+ new_rgba = rgba.copy()
+ new_rgba[3] = alpha
+ return new_rgba
+
+
+def _hintify(entity, alpha=None):
+ """Modifies an entity for use as a 'visual hint'.
+
+ Contacts will be disabled for all geoms within the entity, and its bodies will
+ be converted to "mocap" bodies (which are viewed as fixed from the perspective
+ of the dynamics). The geom alpha values may also be overridden to render the
+ geoms as translucent.
+
+ Args:
+ entity: A `composer.Entity`, modified in place.
+ alpha: Optional float between 0 and 1, used to override the alpha values for
+ all of the geoms in this entity.
+ """
+ for subentity in entity.iter_entities():
+ # TODO(b/112084359): This assumes that all geoms either define explicit RGBA
+ # values, or inherit from the top-level default. It will
+ # not correctly handle more complicated hierarchies of
+ # default classes.
+ if (alpha is not None
+ and subentity.mjcf_model.default.geom is not None
+ and subentity.mjcf_model.default.geom.rgba is not None):
+ subentity.mjcf_model.default.geom.rgba = _replace_alpha(
+ subentity.mjcf_model.default.geom.rgba, alpha=alpha)
+ for body in subentity.mjcf_model.find_all('body'):
+ body.mocap = 'true'
+ for geom in subentity.mjcf_model.find_all('geom'):
+ if alpha is not None and geom.rgba is not None:
+ geom.rgba = _replace_alpha(geom.rgba, alpha=alpha)
+ geom.contype = 0
+ geom.conaffinity = 0
+
+
+def _stack(obs_settings, num_bricks, moveable_base, randomize_order,
+ target_height=None):
+ """Configure and instantiate a Stack task.
+
+ Args:
+ obs_settings: `observations.ObservationSettings` instance.
+ num_bricks: The total number of bricks; must be between 2 and 6.
+ moveable_base: Boolean specifying whether or not the bottom brick should
+ be moveable.
+ randomize_order: Boolean specifying whether to randomize the desired order
+ of bricks in the stack at the start of each episode.
+ target_height: The target number of bricks in the stack in order to get
+ maximum reward. Must be between 2 and `num_bricks`. Defaults to
+ `num_bricks`.
+
+ Returns:
+ An instance of `Stack`.
+ """
+ if target_height is None:
+ target_height = num_bricks
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+ return Stack(arena=arena,
+ arm=arm,
+ hand=hand,
+ num_bricks=num_bricks,
+ target_height=target_height,
+ moveable_base=moveable_base,
+ randomize_order=randomize_order,
+ obs_settings=obs_settings,
+ workspace=_WORKSPACE,
+ control_timestep=constants.CONTROL_TIMESTEP)
+
+
+@registry.add(tags.FEATURES)
+def stack_2_bricks_features():
+ return _stack(obs_settings=observations.PERFECT_FEATURES, num_bricks=2,
+ moveable_base=False, randomize_order=False)
+
+
+@registry.add(tags.VISION)
+def stack_2_bricks_vision():
+ return _stack(obs_settings=observations.VISION, num_bricks=2,
+ moveable_base=False, randomize_order=False)
+
+
+@registry.add(tags.FEATURES)
+def stack_2_bricks_moveable_base_features():
+ return _stack(obs_settings=observations.PERFECT_FEATURES, num_bricks=2,
+ moveable_base=True, randomize_order=False)
+
+
+@registry.add(tags.VISION)
+def stack_2_bricks_moveable_base_vision():
+ return _stack(obs_settings=observations.VISION, num_bricks=2,
+ moveable_base=True, randomize_order=False)
+
+
+@registry.add(tags.FEATURES)
+def stack_3_bricks_features():
+ return _stack(obs_settings=observations.PERFECT_FEATURES, num_bricks=3,
+ moveable_base=False, randomize_order=False)
+
+
+@registry.add(tags.VISION)
+def stack_3_bricks_vision():
+ return _stack(obs_settings=observations.VISION, num_bricks=3,
+ moveable_base=False, randomize_order=False)
+
+
+@registry.add(tags.FEATURES)
+def stack_3_bricks_random_order_features():
+ return _stack(obs_settings=observations.PERFECT_FEATURES, num_bricks=3,
+ moveable_base=False, randomize_order=True)
+
+
+@registry.add(tags.FEATURES)
+def stack_2_of_3_bricks_random_order_features():
+ return _stack(obs_settings=observations.PERFECT_FEATURES, num_bricks=3,
+ moveable_base=False, randomize_order=True, target_height=2)
+
+
+@registry.add(tags.VISION)
+def stack_2_of_3_bricks_random_order_vision():
+ return _stack(obs_settings=observations.VISION, num_bricks=3,
+ moveable_base=False, randomize_order=True, target_height=2)
+
+
+def _reassemble(obs_settings, num_bricks, randomize_initial_order,
+ randomize_desired_order):
+ """Configure and instantiate a `Reassemble` task.
+
+ Args:
+ obs_settings: `observations.ObservationSettings` instance.
+ num_bricks: The total number of bricks; must be between 2 and 6.
+ randomize_initial_order: Boolean specifying whether to randomize the
+ initial order of bricks in the stack at the start of each episode.
+ randomize_desired_order: Boolean specifying whether to independently
+ randomize the desired order of bricks in the stack at the start of each
+ episode. If False, the desired order will be the reverse of the initial
+ order, with the exception of the base brick which is always the same as
+ in the initial order since it is welded in place.
+
+ Returns:
+ An instance of `Reassemble`.
+ """
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+ return Reassemble(arena=arena,
+ arm=arm,
+ hand=hand,
+ num_bricks=num_bricks,
+ randomize_initial_order=randomize_initial_order,
+ randomize_desired_order=randomize_desired_order,
+ obs_settings=obs_settings,
+ workspace=_WORKSPACE,
+ control_timestep=constants.CONTROL_TIMESTEP)
+
+
+@registry.add(tags.FEATURES)
+def reassemble_3_bricks_fixed_order_features():
+ return _reassemble(obs_settings=observations.PERFECT_FEATURES, num_bricks=3,
+ randomize_initial_order=False,
+ randomize_desired_order=False)
+
+
+@registry.add(tags.VISION)
+def reassemble_3_bricks_fixed_order_vision():
+ return _reassemble(obs_settings=observations.VISION, num_bricks=3,
+ randomize_initial_order=False,
+ randomize_desired_order=False)
+
+
+@registry.add(tags.FEATURES)
+def reassemble_5_bricks_random_order_features():
+ return _reassemble(obs_settings=observations.PERFECT_FEATURES, num_bricks=5,
+ randomize_initial_order=True,
+ randomize_desired_order=True)
+
+
+@registry.add(tags.VISION)
+def reassemble_5_bricks_random_order_vision():
+ return _reassemble(obs_settings=observations.VISION, num_bricks=5,
+ randomize_initial_order=True,
+ randomize_desired_order=True)
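
Review note on `bricks.py`: the default ordering and reward weighting used by `Reassemble` can be illustrated with a minimal pure-Python sketch. The helper names `default_desired_order`, `stacking_pairs`, and `combine_rewards` are hypothetical and do not appear in the diff; the code in the diff works on NumPy arrays and calls `np.average`, whereas this sketch uses plain lists and floats.

```python
def default_desired_order(initial_order):
    """Keeps the base brick in place and reverses the remaining bricks."""
    order = list(initial_order)
    # Mirrors `desired[1:] = initial[-1:0:-1]` from the diff, using lists.
    return order[:1] + order[:0:-1]


def stacking_pairs(order):
    """Returns (bottom_idx, top_idx) tuples for adjacent bricks in a stack."""
    return list(zip(order[:-1], order[1:]))


def combine_rewards(close, clicked, close_coef=0.1):
    """Weighted average of the coarse and fine terms, weights [close_coef, 1]."""
    return (close_coef * close + clicked) / (close_coef + 1.0)


order = default_desired_order([0, 1, 2, 3])
print(order)
print(stacking_pairs(order))
print(combine_rewards(close=1.0, clicked=0.0, close_coef=0.0))
```

With `close_coef=0.` the coarse term drops out entirely, which matches the comment in `Reassemble.get_reward` about not penalizing the agent for breaking up the initial stack.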
diff --git a/dm_control/manipulation/explore.py b/dm_control/manipulation/explore.py
new file mode 100644
index 00000000..0dce76c4
--- /dev/null
+++ b/dm_control/manipulation/explore.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A standalone application for visualizing manipulation tasks."""
+
+import functools
+
+from absl import app
+from absl import flags
+from dm_control import manipulation
+
+from dm_control import viewer
+
+flags.DEFINE_enum(
+ 'environment_name', None, manipulation.ALL,
+ 'Optional name of an environment to load. If unspecified '
+ 'a prompt will appear to select one.')
+FLAGS = flags.FLAGS
+
+
+# TODO(b/121187817): Consolidate with dm_control/suite/explore.py
+def prompt_environment_name(prompt, values):
+ environment_name = None
+ while not environment_name:
+ environment_name = input(prompt)
+ if not environment_name or environment_name not in values:
+ print('"%s" is not a valid environment name.' % environment_name)
+ environment_name = None
+ return environment_name
+
+
+def main(argv):
+ del argv
+ environment_name = FLAGS.environment_name
+
+ all_names = list(manipulation.ALL)
+
+ if environment_name is None:
+ print('\n '.join(['Available environments:'] + all_names))
+ environment_name = prompt_environment_name(
+ 'Please select an environment name: ', all_names)
+
+ loader = functools.partial(
+ manipulation.load, environment_name=environment_name)
+ viewer.launch(loader)
+
+
+if __name__ == '__main__':
+ app.run(main)
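
Review note on `explore.py`: `list.index` raises `ValueError` for a missing item instead of returning a negative number, so a prompt loop that validates user input is safer with a membership test. The sketch below is one possible shape for such a loop; the `read_line` parameter is a hypothetical addition (not in the diff) that lets the loop be exercised without a terminal.

```python
def prompt_environment_name(prompt, values, read_line=input):
    """Asks repeatedly until the reply is one of `values`."""
    while True:
        name = read_line(prompt)
        # Membership test: never raises, unlike `values.index(name)`.
        if name and name in values:
            return name
        print('"%s" is not a valid environment name.' % name)
```

Passing a stub such as `read_line=lambda _: next(replies)` makes the rejection path testable with canned replies.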
diff --git a/dm_control/manipulation/lift.py b/dm_control/manipulation/lift.py
new file mode 100644
index 00000000..70637eaa
--- /dev/null
+++ b/dm_control/manipulation/lift.py
@@ -0,0 +1,250 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tasks where the goal is to elevate a prop."""
+
+import collections
+import itertools
+
+from dm_control import composer
+from dm_control.composer import initializers
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.manipulation.shared import arenas
+from dm_control.manipulation.shared import cameras
+from dm_control.manipulation.shared import constants
+from dm_control.manipulation.shared import observations
+from dm_control.manipulation.shared import registry
+from dm_control.manipulation.shared import robots
+from dm_control.manipulation.shared import tags
+from dm_control.manipulation.shared import workspaces
+from dm_control.utils import rewards
+import numpy as np
+
+
+_LiftWorkspace = collections.namedtuple(
+ '_LiftWorkspace', ['prop_bbox', 'tcp_bbox', 'arm_offset'])
+
+_DUPLO_WORKSPACE = _LiftWorkspace(
+ prop_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.0),
+ upper=(0.1, 0.1, 0.0)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.2),
+ upper=(0.1, 0.1, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_BOX_SIZE = 0.09
+_BOX_MASS = 1.3
+_BOX_WORKSPACE = _LiftWorkspace(
+ prop_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _BOX_SIZE),
+ upper=(0.1, 0.1, _BOX_SIZE)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.2),
+ upper=(0.1, 0.1, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_DISTANCE_TO_LIFT = 0.3
+
+
+class _VertexSitesMixin:
+ """Mixin class that adds sites corresponding to the vertices of a box."""
+
+ def _add_vertex_sites(self, box_geom_or_site):
+ """Add sites corresponding to the vertices of a box geom or site."""
+ offsets = (
+ (-half_length, half_length) for half_length in box_geom_or_site.size)
+ site_positions = np.vstack(list(itertools.product(*offsets)))
+ if box_geom_or_site.pos is not None:
+ site_positions += box_geom_or_site.pos
+ self._vertices = []
+ for i, pos in enumerate(site_positions):
+ site = box_geom_or_site.parent.add(
+ 'site', name='vertex_' + str(i), pos=pos, type='sphere', size=[0.002],
+ rgba=constants.RED, group=constants.TASK_SITE_GROUP)
+ self._vertices.append(site)
+
+ @property
+ def vertices(self):
+ return self._vertices
+
+
+class _BoxWithVertexSites(props.Primitive, _VertexSitesMixin):
+ """A box `Primitive` with sites marking the vertices of its geom."""
+
+ def _build(self, *args, **kwargs):
+ super()._build(*args, geom_type='box', **kwargs)
+ self._add_vertex_sites(self.geom)
+
+
+class _DuploWithVertexSites(props.Duplo, _VertexSitesMixin):
+ """Subclass of `Duplo` with sites marking the vertices of its sensor site."""
+
+ def _build(self, *args, **kwargs):
+ super()._build(*args, **kwargs)
+ self._add_vertex_sites(self.mjcf_model.find('site', 'bounding_box'))
+
+
+class Lift(composer.Task):
+ """A task where the goal is to elevate a prop."""
+
+ def __init__(
+ self, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
+ """Initializes a new `Lift` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ prop: `composer.Entity` instance.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: `_LiftWorkspace` specifying the placement of the prop and TCP.
+ control_timestep: Float specifying the control timestep in seconds.
+ """
+ self._arena = arena
+ self._arm = arm
+ self._hand = hand
+ self._arm.attach(self._hand)
+ self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
+ self.control_timestep = control_timestep
+
+ # Add custom camera observable.
+ self._task_observables = cameras.add_camera_observables(
+ arena, obs_settings, cameras.FRONT_CLOSE)
+
+ self._tcp_initializer = initializers.ToolCenterPointInitializer(
+ self._hand, self._arm,
+ position=distributions.Uniform(*workspace.tcp_bbox),
+ quaternion=workspaces.DOWN_QUATERNION)
+
+ self._prop = prop
+ self._arena.add_free_entity(prop)
+ self._prop_placer = initializers.PropPlacer(
+ props=[prop],
+ position=distributions.Uniform(*workspace.prop_bbox),
+ quaternion=workspaces.uniform_z_rotation,
+ ignore_collisions=True,
+ settle_physics=True)
+
+ # Add sites for visualizing bounding boxes and target height.
+ self._target_height_site = workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=(-1, -1, 0), upper=(1, 1, 0),
+ rgba=constants.RED, name='target_height')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
+ rgba=constants.GREEN, name='tcp_spawn_area')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.prop_bbox.lower, upper=workspace.prop_bbox.upper,
+ rgba=constants.BLUE, name='prop_spawn_area')
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def arm(self):
+ return self._arm
+
+ @property
+ def hand(self):
+ return self._hand
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ def _get_height_of_lowest_vertex(self, physics):
+ return min(physics.bind(self._prop.vertices).xpos[:, 2])
+
+ def get_reward(self, physics):
+ prop_height = self._get_height_of_lowest_vertex(physics)
+ return rewards.tolerance(prop_height,
+ bounds=(self._target_height, np.inf),
+ margin=_DISTANCE_TO_LIFT,
+ value_at_margin=0,
+ sigmoid='linear')
+
+ def initialize_episode(self, physics, random_state):
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._prop_placer(physics, random_state)
+ self._tcp_initializer(physics, random_state)
+ # Compute the target height based on the initial height of the prop's
+ # lowest vertex after settling.
+ initial_prop_height = self._get_height_of_lowest_vertex(physics)
+ self._target_height = _DISTANCE_TO_LIFT + initial_prop_height
+ physics.bind(self._target_height_site).pos[2] = self._target_height
+
+
+def _lift(obs_settings, prop_name):
+ """Configure and instantiate a Lift task.
+
+ Args:
+ obs_settings: `observations.ObservationSettings` instance.
+ prop_name: The name of the prop to be lifted. Must be either 'duplo' or
+ 'box'.
+
+ Returns:
+ An instance of `lift.Lift`.
+
+ Raises:
+ ValueError: If `prop_name` is neither 'duplo' nor 'box'.
+ """
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+
+ if prop_name == 'duplo':
+ workspace = _DUPLO_WORKSPACE
+ prop = _DuploWithVertexSites(
+ observable_options=observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES))
+ elif prop_name == 'box':
+ workspace = _BOX_WORKSPACE
+ # NB: The box is intentionally too large to pick up with a pinch grip.
+ prop = _BoxWithVertexSites(
+ size=[_BOX_SIZE] * 3,
+ observable_options=observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES))
+ prop.geom.mass = _BOX_MASS
+ else:
+ raise ValueError('`prop_name` must be either \'duplo\' or \'box\'.')
+ task = Lift(arena=arena, arm=arm, hand=hand, prop=prop, workspace=workspace,
+ obs_settings=obs_settings,
+ control_timestep=constants.CONTROL_TIMESTEP)
+ return task
+
+
+@registry.add(tags.FEATURES)
+def lift_brick_features():
+ return _lift(obs_settings=observations.PERFECT_FEATURES, prop_name='duplo')
+
+
+@registry.add(tags.VISION)
+def lift_brick_vision():
+ return _lift(obs_settings=observations.VISION, prop_name='duplo')
+
+
+@registry.add(tags.FEATURES)
+def lift_large_box_features():
+ return _lift(obs_settings=observations.PERFECT_FEATURES, prop_name='box')
+
+
+@registry.add(tags.VISION)
+def lift_large_box_vision():
+ return _lift(obs_settings=observations.VISION, prop_name='box')
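
Review note on `lift.py`: the vertex-site construction and the shape of the lift reward can be sketched in pure Python. The names `box_vertices` and `lift_reward` are hypothetical. The reward function reflects my reading of `rewards.tolerance` with `sigmoid='linear'` and `value_at_margin=0`, which is an assumption and not checked against the library here.

```python
import itertools

_DISTANCE_TO_LIFT = 0.3  # Same value as the constant in the diff.


def box_vertices(half_sizes, center=(0.0, 0.0, 0.0)):
    """Returns the 8 corner positions of an axis-aligned box."""
    offsets = [(-h, h) for h in half_sizes]
    return [tuple(c + o for c, o in zip(center, corner))
            for corner in itertools.product(*offsets)]


def lift_reward(lowest_height, target_height, margin=_DISTANCE_TO_LIFT):
    """Linear ramp: 0 at `margin` below the target, 1 at or above it."""
    if lowest_height >= target_height:
        return 1.0
    shortfall = (target_height - lowest_height) / margin
    return max(0.0, 1.0 - shortfall)


# A box with half-size 0.09 resting on the table, as in `_BOX_WORKSPACE`.
vertices = box_vertices(half_sizes=(0.09, 0.09, 0.09), center=(0.0, 0.0, 0.09))
initial_height = min(z for _, _, z in vertices)
target_height = initial_height + _DISTANCE_TO_LIFT
print(len(vertices), initial_height, target_height)
print(lift_reward(initial_height, target_height))
```

Because the target is set to the initial lowest-vertex height plus `_DISTANCE_TO_LIFT`, and the margin equals the same distance, the reward under this reading starts at zero when the episode begins and rises linearly as the prop is raised.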
diff --git a/dm_control/manipulation/manipulation_test.py b/dm_control/manipulation/manipulation_test.py
new file mode 100644
index 00000000..c0e9c1f0
--- /dev/null
+++ b/dm_control/manipulation/manipulation_test.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for `dm_control.manipulation`."""
+
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import manipulation
+import numpy as np
+
+
+flags.DEFINE_boolean(
+ 'fix_seed', True,
+ 'Whether to fix the seed for the environment\'s random number generator. '
+ 'This is the default since it prevents non-deterministic failures, but it may '
+ 'be useful to allow the seed to vary in some cases, for example when '
+ 'repeating a test many times in order to detect rare failure events.')
+
+FLAGS = flags.FLAGS
+
+_FIX_SEED = None
+_NUM_EPISODES = 5
+_NUM_STEPS_PER_EPISODE = 10
+
+
+def _get_fix_seed():
+ global _FIX_SEED
+ if _FIX_SEED is None:
+ if FLAGS.is_parsed():
+ _FIX_SEED = FLAGS.fix_seed
+ else:
+ _FIX_SEED = FLAGS['fix_seed'].default
+ return _FIX_SEED
+
+
+class ManipulationTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def _validate_observation(self, observation, observation_spec):
+ self.assertEqual(list(observation.keys()), list(observation_spec.keys()))
+ for name, array_spec in observation_spec.items():
+ array_spec.validate(observation[name])
+
+ def _validate_reward_range(self, reward):
+ self.assertIsInstance(reward, float)
+ self.assertBetween(reward, 0, 1)
+
+ def _validate_discount(self, discount):
+ self.assertIsInstance(discount, float)
+ self.assertBetween(discount, 0, 1)
+
+ @parameterized.parameters(*manipulation.ALL)
+ def test_task_runs(self, task_name):
+ """Tests that the environment runs and is coherent with its specs."""
+ seed = 99 if _get_fix_seed() else None
+ env = manipulation.load(task_name, seed=seed)
+ random_state = np.random.RandomState(seed)
+
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+ self.assertTrue(np.all(np.isfinite(action_spec.minimum)))
+ self.assertTrue(np.all(np.isfinite(action_spec.maximum)))
+
+ # Run a partial episode, check observations, rewards, discount.
+ for _ in range(_NUM_EPISODES):
+ time_step = env.reset()
+ for _ in range(_NUM_STEPS_PER_EPISODE):
+ self._validate_observation(time_step.observation, observation_spec)
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ self.assertIsNone(time_step.discount)
+ else:
+ self._validate_reward_range(time_step.reward)
+ self._validate_discount(time_step.discount)
+ action = random_state.uniform(action_spec.minimum, action_spec.maximum)
+ time_step = env.step(action)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/manipulation/place.py b/dm_control/manipulation/place.py
new file mode 100644
index 00000000..07fd774e
--- /dev/null
+++ b/dm_control/manipulation/place.py
@@ -0,0 +1,293 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A task where the goal is to place a movable prop on top of a fixed prop."""
+
+import collections
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer import initializers
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.manipulation.shared import arenas
+from dm_control.manipulation.shared import cameras
+from dm_control.manipulation.shared import constants
+from dm_control.manipulation.shared import observations
+from dm_control.manipulation.shared import registry
+from dm_control.manipulation.shared import robots
+from dm_control.manipulation.shared import tags
+from dm_control.manipulation.shared import workspaces
+from dm_control.utils import rewards
+import numpy as np
+
+
+_PlaceWorkspace = collections.namedtuple(
+ '_PlaceWorkspace', ['prop_bbox', 'target_bbox', 'tcp_bbox', 'arm_offset'])
+
+_TARGET_RADIUS = 0.05
+_PEDESTAL_RADIUS = 0.07
+
+# Ensures that the prop does not collide with the table during initialization.
+_PROP_Z_OFFSET = 1e-6
+
+_WORKSPACE = _PlaceWorkspace(
+ prop_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PROP_Z_OFFSET),
+ upper=(0.1, 0.1, _PROP_Z_OFFSET)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PEDESTAL_RADIUS + 0.1),
+ upper=(0.1, 0.1, 0.4)),
+ target_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PEDESTAL_RADIUS),
+ upper=(0.1, 0.1, _PEDESTAL_RADIUS + 0.1)),
+ arm_offset=robots.ARM_OFFSET)
+
+
+class SphereCradle(composer.Entity):
+ """A concave shape for easy placement."""
+ _SPHERE_COUNT = 3
+
+ def _build(self):
+ self._mjcf_root = mjcf.element.RootElement(model='cradle')
+ sphere_radius = _PEDESTAL_RADIUS * 0.7
+ for ang in np.linspace(0, 2*np.pi, num=self._SPHERE_COUNT, endpoint=False):
+ pos = 0.7 * sphere_radius * np.array([np.sin(ang), np.cos(ang), -1])
+ self._mjcf_root.worldbody.add(
+ 'geom', type='sphere', size=[sphere_radius], condim=6, pos=pos)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class Pedestal(composer.Entity):
+ """A narrow pillar to elevate the target."""
+ _HEIGHT = 0.2
+
+ def _build(self, cradle, target_radius):
+ self._mjcf_root = mjcf.element.RootElement(model='pedestal')
+
+ self._mjcf_root.worldbody.add(
+ 'geom', type='capsule', size=[_PEDESTAL_RADIUS],
+ fromto=[0, 0, -_PEDESTAL_RADIUS,
+ 0, 0, -(self._HEIGHT + _PEDESTAL_RADIUS)])
+ attachment_site = self._mjcf_root.worldbody.add(
+ 'site', type='sphere', size=(0.003,), group=constants.TASK_SITE_GROUP)
+ self.attach(cradle, attachment_site)
+ self._target_site = workspaces.add_target_site(
+ body=self.mjcf_model.worldbody,
+ radius=target_radius, rgba=constants.RED)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def target_site(self):
+ return self._target_site
+
+ def _build_observables(self):
+ return PedestalObservables(self)
+
+
+class PedestalObservables(composer.Observables):
+ """Observables for the `Pedestal` prop."""
+
+ @define.observable
+ def position(self):
+ return observable.MJCFFeature('xpos', self._entity.target_site)
+
+
+class Place(composer.Task):
+ """Place the prop on top of another fixed prop held up by a pedestal."""
+
+ def __init__(self, arena, arm, hand, prop, obs_settings, workspace,
+ control_timestep, cradle):
+ """Initializes a new `Place` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ prop: `composer.Entity` instance.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: A `_PlaceWorkspace` instance.
+ control_timestep: Float specifying the control timestep in seconds.
+ cradle: `composer.Entity` onto which the `prop` must be placed.
+ """
+ self._arena = arena
+ self._arm = arm
+ self._hand = hand
+ self._arm.attach(self._hand)
+ self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
+ self.control_timestep = control_timestep
+
+ # Add custom camera observable.
+ self._task_observables = cameras.add_camera_observables(
+ arena, obs_settings, cameras.FRONT_CLOSE)
+
+ self._tcp_initializer = initializers.ToolCenterPointInitializer(
+ self._hand, self._arm,
+ position=distributions.Uniform(*workspace.tcp_bbox),
+ quaternion=workspaces.DOWN_QUATERNION)
+
+ self._prop = prop
+ self._prop_frame = self._arena.add_free_entity(prop)
+ self._pedestal = Pedestal(cradle=cradle, target_radius=_TARGET_RADIUS)
+ self._arena.attach(self._pedestal)
+
+ for obs in self._pedestal.observables.as_dict().values():
+ obs.configure(**obs_settings.prop_pose._asdict())
+
+ self._prop_placer = initializers.PropPlacer(
+ props=[prop],
+ position=distributions.Uniform(*workspace.prop_bbox),
+ quaternion=workspaces.uniform_z_rotation,
+ settle_physics=True,
+ max_attempts_per_prop=50)
+
+ self._pedestal_placer = initializers.PropPlacer(
+ props=[self._pedestal],
+ position=distributions.Uniform(*workspace.target_bbox),
+ settle_physics=False)
+
+ # Add sites for visual debugging.
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.tcp_bbox.lower,
+ upper=workspace.tcp_bbox.upper,
+ rgba=constants.GREEN, name='tcp_spawn_area')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.prop_bbox.lower,
+ upper=workspace.prop_bbox.upper,
+ rgba=constants.BLUE, name='prop_spawn_area')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.target_bbox.lower,
+ upper=workspace.target_bbox.upper,
+ rgba=constants.CYAN, name='pedestal_spawn_area')
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def arm(self):
+ return self._arm
+
+ @property
+ def hand(self):
+ return self._hand
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ def initialize_episode(self, physics, random_state):
+ self._pedestal_placer(physics, random_state,
+ ignore_contacts_with_entities=[self._prop])
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._tcp_initializer(physics, random_state)
+ self._prop_placer(physics, random_state)
+
+ def get_reward(self, physics):
+ target = physics.bind(self._pedestal.target_site).xpos
+ obj = physics.bind(self._prop_frame).xpos
+ tcp = physics.bind(self._hand.tool_center_point).xpos
+
+ tcp_to_obj = np.linalg.norm(obj - tcp)
+ grasp = rewards.tolerance(tcp_to_obj,
+ bounds=(0, _TARGET_RADIUS),
+ margin=_TARGET_RADIUS,
+ sigmoid='long_tail')
+
+ obj_to_target = np.linalg.norm(obj - target)
+ in_place = rewards.tolerance(obj_to_target,
+ bounds=(0, _TARGET_RADIUS),
+ margin=_TARGET_RADIUS,
+ sigmoid='long_tail')
+
+ tcp_to_target = np.linalg.norm(tcp - target)
+ hand_away = rewards.tolerance(tcp_to_target,
+ bounds=(4*_TARGET_RADIUS, np.inf),
+ margin=3*_TARGET_RADIUS,
+ sigmoid='long_tail')
+ in_place_weight = 10.
+ grasp_or_hand_away = grasp * (1 - in_place) + hand_away * in_place
+ return (
+ grasp_or_hand_away + in_place_weight * in_place) / (1 + in_place_weight)
+
+
+def _place(obs_settings, cradle_prop_name):
+ """Configure and instantiate a Place task.
+
+ Args:
+ obs_settings: `observations.ObservationSettings` instance.
+ cradle_prop_name: The name of the prop onto which the Duplo brick must be
+ placed. Must be either 'duplo' or 'cradle'.
+
+ Returns:
+ An instance of `Place`.
+
+ Raises:
+ ValueError: If `cradle_prop_name` is neither 'duplo' nor 'cradle'.
+ """
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+
+ prop = props.Duplo(
+ observable_options=observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES))
+ if cradle_prop_name == 'duplo':
+ cradle = props.Duplo()
+ elif cradle_prop_name == 'cradle':
+ cradle = SphereCradle()
+ else:
+ raise ValueError(
+ '`cradle_prop_name` must be either \'duplo\' or \'cradle\'.')
+
+ task = Place(arena=arena, arm=arm, hand=hand, prop=prop,
+ obs_settings=obs_settings,
+ workspace=_WORKSPACE,
+ control_timestep=constants.CONTROL_TIMESTEP,
+ cradle=cradle)
+ return task
+
+
+@registry.add(tags.FEATURES)
+def place_brick_features():
+ return _place(obs_settings=observations.PERFECT_FEATURES,
+ cradle_prop_name='duplo')
+
+
+@registry.add(tags.VISION)
+def place_brick_vision():
+ return _place(obs_settings=observations.VISION, cradle_prop_name='duplo')
+
+
+@registry.add(tags.FEATURES)
+def place_cradle_features():
+ return _place(obs_settings=observations.PERFECT_FEATURES,
+ cradle_prop_name='cradle')
+
+
+@registry.add(tags.VISION)
+def place_cradle_vision():
+ return _place(obs_settings=observations.VISION, cradle_prop_name='cradle')
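Note on `Place.get_reward`: the three shaped terms are blended so that the agent is rewarded for grasping while the prop is away from the target, and for moving the hand away once the prop is in place. The following is a minimal standalone sketch of that blend, assuming each term has already been computed and lies in [0, 1]; the helper name `blend_place_reward` is hypothetical and does not appear in the diff.

```python
def blend_place_reward(grasp, in_place, hand_away, in_place_weight=10.0):
    """Combines the shaped terms the same way `Place.get_reward` does.

    All three inputs are assumed to lie in [0, 1].
    """
    # While the prop is far from the target (in_place ~ 0) the grasp term
    # dominates; once it is in place (in_place ~ 1) the hand-away term does.
    grasp_or_hand_away = grasp * (1 - in_place) + hand_away * in_place
    # Normalize so that the maximum achievable reward is 1.
    return (grasp_or_hand_away + in_place_weight * in_place) / (
        1 + in_place_weight)


# A perfect grasp with the prop not yet placed earns only a small reward.
grasp_only = blend_place_reward(grasp=1.0, in_place=0.0, hand_away=0.0)
# Prop in place and hand withdrawn earns the full reward.
finished = blend_place_reward(grasp=0.0, in_place=1.0, hand_away=1.0)
```

With the weight of 10 used in the diff, placement accounts for 10/11 of the maximum reward, so the grasp and hand-away terms act mainly as shaping signals.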
diff --git a/dm_control/manipulation/props/__init__.py b/dm_control/manipulation/props/__init__.py
new file mode 100644
index 00000000..7d0b9099
--- /dev/null
+++ b/dm_control/manipulation/props/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2021 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Props for manipulation tasks."""
+from dm_control.manipulation.props.primitive import Box
+from dm_control.manipulation.props.primitive import BoxWithSites
+from dm_control.manipulation.props.primitive import Capsule
+from dm_control.manipulation.props.primitive import Cylinder
+from dm_control.manipulation.props.primitive import Ellipsoid
+from dm_control.manipulation.props.primitive import Primitive
+from dm_control.manipulation.props.primitive import Sphere
diff --git a/dm_control/manipulation/props/primitive.py b/dm_control/manipulation/props/primitive.py
new file mode 100644
index 00000000..92fa6421
--- /dev/null
+++ b/dm_control/manipulation/props/primitive.py
@@ -0,0 +1,212 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Props made of a single primitive MuJoCo geom."""
+import itertools
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+import numpy as np
+_DEFAULT_HALF_LENGTHS = [0.05, 0.1, 0.15]
+
+
+class Primitive(composer.Entity):
+ """A primitive MuJoCo geom prop."""
+
+ def _build(self, geom_type, size, mass=None, name=None):
+ """Initializes this prop.
+
+ Args:
+ geom_type: a string, one of the types supported by MuJoCo.
+ size: a list or numpy array of up to 3 numbers, depending on the type.
+ mass: The mass for the primitive geom.
+ name: (optional) A string, the name of this prop.
+ """
+ size = np.reshape(np.asarray(size), -1)
+ self._mjcf_root = mjcf.element.RootElement(model=name)
+
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', name='body_geom', type=geom_type, size=size, mass=mass)
+
+ touch_sensor = self._mjcf_root.worldbody.add(
+ 'site', type=geom_type, name='touch_sensor', size=size*1.05,
+ rgba=[1, 1, 1, 0.1], # touch sensor site is almost transparent
+ group=composer.SENSOR_SITES_GROUP)
+
+ self._touch = self._mjcf_root.sensor.add(
+ 'touch', site=touch_sensor)
+
+ self._position = self._mjcf_root.sensor.add(
+ 'framepos', name='position', objtype='geom', objname=self.geom)
+
+ self._orientation = self._mjcf_root.sensor.add(
+ 'framequat', name='orientation', objtype='geom',
+ objname=self.geom)
+
+ self._linear_velocity = self._mjcf_root.sensor.add(
+ 'framelinvel', name='linear_velocity', objtype='geom',
+ objname=self.geom)
+
+ self._angular_velocity = self._mjcf_root.sensor.add(
+ 'frameangvel', name='angular_velocity', objtype='geom',
+ objname=self.geom)
+
+ self._name = name
+
+ def _build_observables(self):
+ return PrimitiveObservables(self)
+
+ @property
+ def geom(self):
+ """Returns the primitive's geom, e.g., to change color or friction."""
+ return self._geom
+
+ @property
+ def touch(self):
+ """Returns the touch sensor, e.g., for observations and reward."""
+ return self._touch
+
+ @property
+ def position(self):
+ """Ground truth position sensor."""
+ return self._position
+
+ @property
+ def orientation(self):
+ """Ground truth angular position sensor."""
+ return self._orientation
+
+ @property
+ def linear_velocity(self):
+ """Ground truth linear velocity sensor."""
+ return self._linear_velocity
+
+ @property
+ def angular_velocity(self):
+ """Ground truth angular velocity sensor."""
+ return self._angular_velocity
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def name(self):
+ return self._name
+
+
+class PrimitiveObservables(composer.Observables,
+ composer.FreePropObservableMixin):
+ """Primitive entity's observables."""
+
+ @define.observable
+ def position(self):
+ return observable.MJCFFeature('sensordata', self._entity.position)
+
+ @define.observable
+ def orientation(self):
+ return observable.MJCFFeature('sensordata', self._entity.orientation)
+
+ @define.observable
+ def linear_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.linear_velocity)
+
+ @define.observable
+ def angular_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.angular_velocity)
+
+ @define.observable
+ def touch(self):
+ return observable.MJCFFeature('sensordata', self._entity.touch)
+
+
+class Sphere(Primitive):
+ """A class representing a sphere prop."""
+
+ def _build(self, radius=0.05, mass=None, name='sphere'):
+ super(Sphere, self)._build(
+ geom_type='sphere', size=radius, mass=mass, name=name)
+
+
+class Box(Primitive):
+ """A class representing a box prop."""
+
+ def _build(self, half_lengths=None, mass=None, name='box'):
+ half_lengths = half_lengths or _DEFAULT_HALF_LENGTHS
+ super(Box, self)._build(geom_type='box',
+ size=half_lengths,
+ mass=mass,
+ name=name)
+
+
+class BoxWithSites(Box):
+ """A class representing a box prop with sites on the corners."""
+
+ def _build(self, half_lengths=None, mass=None, name='box'):
+ half_lengths = half_lengths or _DEFAULT_HALF_LENGTHS
+ super(BoxWithSites, self)._build(half_lengths=half_lengths, mass=mass,
+ name=name)
+
+ corner_positions = itertools.product([half_lengths[0], -half_lengths[0]],
+ [half_lengths[1], -half_lengths[1]],
+ [half_lengths[2], -half_lengths[2]])
+ corner_sites = []
+ for i, corner_pos in enumerate(corner_positions):
+ corner_sites.append(
+ self._mjcf_root.worldbody.add(
+ 'site',
+ type='sphere',
+ name='corner_{}'.format(i),
+ size=[0.1],
+ pos=corner_pos,
+ rgba=[1, 0, 0, 1.0],
+ group=composer.SENSOR_SITES_GROUP))
+ self._corner_sites = tuple(corner_sites)
+
+ @property
+ def corner_sites(self):
+ return self._corner_sites
+
+
+class Ellipsoid(Primitive):
+ """A class representing an ellipsoid prop."""
+
+ def _build(self, radii=None, mass=None, name='ellipsoid'):
+ radii = radii or _DEFAULT_HALF_LENGTHS
+ super(Ellipsoid, self)._build(geom_type='ellipsoid',
+ size=radii,
+ mass=mass,
+ name=name)
+
+
+class Cylinder(Primitive):
+ """A class representing a cylinder prop."""
+
+ def _build(self, radius=0.05, half_length=0.15, mass=None, name='cylinder'):
+ super(Cylinder, self)._build(geom_type='cylinder',
+ size=[radius, half_length],
+ mass=mass,
+ name=name)
+
+
+class Capsule(Primitive):
+ """A class representing a capsule prop."""
+
+ def _build(self, radius=0.05, half_length=0.15, mass=None, name='capsule'):
+ super(Capsule, self)._build(geom_type='capsule',
+ size=[radius, half_length],
+ mass=mass,
+ name=name)
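Note on `BoxWithSites`: the corner sites are generated by taking the Cartesian product of the positive and negative half-length along each axis. The sketch below reproduces only that enumeration with the standard library, using the default half-lengths from the diff; it does not build any MJCF elements.

```python
import itertools

# Default half-lengths from primitive.py.
half_lengths = [0.05, 0.1, 0.15]

# Each axis contributes (+h, -h); the product yields all 8 box corners.
corner_positions = list(itertools.product(
    [half_lengths[0], -half_lengths[0]],
    [half_lengths[1], -half_lengths[1]],
    [half_lengths[2], -half_lengths[2]]))

# Site names follow the same 'corner_{i}' pattern as in the diff.
corner_names = ['corner_{}'.format(i) for i in range(len(corner_positions))]
```

Because `itertools.product` varies the last axis fastest, `corner_0` is the all-positive corner and `corner_7` is the all-negative one.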
diff --git a/dm_control/manipulation/reach.py b/dm_control/manipulation/reach.py
new file mode 100644
index 00000000..b2c8d4d2
--- /dev/null
+++ b/dm_control/manipulation/reach.py
@@ -0,0 +1,210 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A task where the goal is to move the hand close to a target prop or site."""
+
+import collections
+
+from dm_control import composer
+from dm_control.composer import initializers
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.manipulation.shared import arenas
+from dm_control.manipulation.shared import cameras
+from dm_control.manipulation.shared import constants
+from dm_control.manipulation.shared import observations
+from dm_control.manipulation.shared import registry
+from dm_control.manipulation.shared import robots
+from dm_control.manipulation.shared import tags
+from dm_control.manipulation.shared import workspaces
+from dm_control.utils import rewards
+import numpy as np
+
+
+_ReachWorkspace = collections.namedtuple(
+ '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
+
+# Ensures that the props are not touching the table before settling.
+_PROP_Z_OFFSET = 0.001
+
+_DUPLO_WORKSPACE = _ReachWorkspace(
+ target_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PROP_Z_OFFSET),
+ upper=(0.1, 0.1, _PROP_Z_OFFSET)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.2),
+ upper=(0.1, 0.1, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_SITE_WORKSPACE = _ReachWorkspace(
+ target_bbox=workspaces.BoundingBox(
+ lower=(-0.2, -0.2, 0.02),
+ upper=(0.2, 0.2, 0.4)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.2, -0.2, 0.02),
+ upper=(0.2, 0.2, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_TARGET_RADIUS = 0.05
+
+
+class Reach(composer.Task):
+ """Bring the hand close to a target prop or site."""
+
+ def __init__(
+ self, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
+ """Initializes a new `Reach` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ prop: `composer.Entity` instance specifying the prop to reach to, or None
+ in which case the target is a fixed site whose position is specified by
+ the workspace.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
+ control_timestep: Float specifying the control timestep in seconds.
+ """
+ self._arena = arena
+ self._arm = arm
+ self._hand = hand
+ self._arm.attach(self._hand)
+ self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
+ self.control_timestep = control_timestep
+ self._tcp_initializer = initializers.ToolCenterPointInitializer(
+ self._hand, self._arm,
+ position=distributions.Uniform(*workspace.tcp_bbox),
+ quaternion=workspaces.DOWN_QUATERNION)
+
+ # Add custom camera observable.
+ self._task_observables = cameras.add_camera_observables(
+ arena, obs_settings, cameras.FRONT_CLOSE)
+
+ target_pos_distribution = distributions.Uniform(*workspace.target_bbox)
+ self._prop = prop
+ if prop:
+ # The prop itself is used to visualize the target location.
+ self._make_target_site(parent_entity=prop, visible=False)
+ self._target = self._arena.add_free_entity(prop)
+ self._prop_placer = initializers.PropPlacer(
+ props=[prop],
+ position=target_pos_distribution,
+ quaternion=workspaces.uniform_z_rotation,
+ settle_physics=True)
+ else:
+ self._target = self._make_target_site(parent_entity=arena, visible=True)
+ self._target_placer = target_pos_distribution
+
+ obs = observable.MJCFFeature('pos', self._target)
+ obs.configure(**obs_settings.prop_pose._asdict())
+ self._task_observables['target_position'] = obs
+
+ # Add sites for visualizing the TCP and target bounding boxes.
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
+ rgba=constants.GREEN, name='tcp_spawn_area')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
+ rgba=constants.BLUE, name='target_spawn_area')
+
+ def _make_target_site(self, parent_entity, visible):
+ return workspaces.add_target_site(
+ body=parent_entity.mjcf_model.worldbody,
+ radius=_TARGET_RADIUS, visible=visible,
+ rgba=constants.RED, name='target_site')
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def arm(self):
+ return self._arm
+
+ @property
+ def hand(self):
+ return self._hand
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ def get_reward(self, physics):
+ hand_pos = physics.bind(self._hand.tool_center_point).xpos
+ target_pos = physics.bind(self._target).xpos
+ distance = np.linalg.norm(hand_pos - target_pos)
+ return rewards.tolerance(
+ distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
+
+ def initialize_episode(self, physics, random_state):
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._tcp_initializer(physics, random_state)
+ if self._prop:
+ self._prop_placer(physics, random_state)
+ else:
+ physics.bind(self._target).pos = (
+ self._target_placer(random_state=random_state))
+
+
+def _reach(obs_settings, use_site):
+ """Configure and instantiate a `Reach` task.
+
+ Args:
+ obs_settings: An `observations.ObservationSettings` instance.
+ use_site: Boolean, if True then the target will be a fixed site, otherwise
+ it will be a moveable Duplo brick.
+
+ Returns:
+ An instance of `reach.Reach`.
+ """
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+ if use_site:
+ workspace = _SITE_WORKSPACE
+ prop = None
+ else:
+ workspace = _DUPLO_WORKSPACE
+ prop = props.Duplo(observable_options=observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES))
+ task = Reach(arena=arena, arm=arm, hand=hand, prop=prop,
+ obs_settings=obs_settings,
+ workspace=workspace,
+ control_timestep=constants.CONTROL_TIMESTEP)
+ return task
+
+
+@registry.add(tags.FEATURES, tags.EASY)
+def reach_duplo_features():
+ return _reach(obs_settings=observations.PERFECT_FEATURES, use_site=False)
+
+
+@registry.add(tags.VISION, tags.EASY)
+def reach_duplo_vision():
+ return _reach(obs_settings=observations.VISION, use_site=False)
+
+
+@registry.add(tags.FEATURES, tags.EASY)
+def reach_site_features():
+ return _reach(obs_settings=observations.PERFECT_FEATURES, use_site=True)
+
+
+@registry.add(tags.VISION, tags.EASY)
+def reach_site_vision():
+ return _reach(obs_settings=observations.VISION, use_site=True)
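Note on `Reach.get_reward`: the reward is 1 while the tool center point is within `_TARGET_RADIUS` of the target and decays smoothly outside it. The sketch below is a simplified pure-Python re-implementation for illustration, not the library code. It assumes that `rewards.tolerance` defaults to a Gaussian sigmoid with a value of 0.1 at one margin outside the bounds; the helper names `gaussian_tolerance` and `reach_reward` are hypothetical.

```python
import math

_TARGET_RADIUS = 0.05  # Same value as in reach.py.
_VALUE_AT_MARGIN = 0.1  # Assumed default of `rewards.tolerance`.


def gaussian_tolerance(x, bounds, margin, value_at_margin=_VALUE_AT_MARGIN):
    """Returns 1 inside `bounds`, decaying as a Gaussian outside them."""
    lower, upper = bounds
    if lower <= x <= upper:
        return 1.0
    if margin == 0:
        return 0.0
    # Distance outside the bounds, measured in units of the margin.
    d = (lower - x if x < lower else x - upper) / margin
    # Chosen so that the output equals `value_at_margin` when d == 1.
    scale = math.sqrt(-2.0 * math.log(value_at_margin))
    return math.exp(-0.5 * (d * scale) ** 2)


def reach_reward(hand_pos, target_pos):
    """Reward based on the Euclidean distance between hand and target."""
    distance = math.sqrt(
        sum((h - t) ** 2 for h, t in zip(hand_pos, target_pos)))
    return gaussian_tolerance(
        distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)


inside = reach_reward((0.0, 0.0, 0.03), (0.0, 0.0, 0.0))
at_margin = reach_reward((0.0, 0.0, 0.1), (0.0, 0.0, 0.0))
```

Under these assumptions, a hand 3 cm from the target earns the full reward, while a hand 10 cm away (one margin beyond the radius) earns about 0.1.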
diff --git a/dm_control/manipulation/shared/__init__.py b/dm_control/manipulation/shared/__init__.py
new file mode 100644
index 00000000..a514c4bb
--- /dev/null
+++ b/dm_control/manipulation/shared/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/manipulation/shared/arenas.py b/dm_control/manipulation/shared/arenas.py
new file mode 100644
index 00000000..2245ed07
--- /dev/null
+++ b/dm_control/manipulation/shared/arenas.py
@@ -0,0 +1,99 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Suite-specific arena class."""
+
+
+from dm_control import composer
+
+
+class Standard(composer.Arena):
+ """Suite-specific subclass of the standard Composer arena."""
+
+ def _build(self, name=None):
+ """Initializes this arena.
+
+ Args:
+ name: (optional) A string, the name of this arena. If `None`, use the
+ model name defined in the MJCF file.
+ """
+ super()._build(name=name)
+
+ # Add visual assets.
+ self.mjcf_model.asset.add(
+ 'texture',
+ type='skybox',
+ builtin='gradient',
+ rgb1=(0.4, 0.6, 0.8),
+ rgb2=(0., 0., 0.),
+ width=100,
+ height=100)
+ groundplane_texture = self.mjcf_model.asset.add(
+ 'texture',
+ name='groundplane',
+ type='2d',
+ builtin='checker',
+ rgb1=(0.2, 0.3, 0.4),
+ rgb2=(0.1, 0.2, 0.3),
+ width=300,
+ height=300,
+ mark='edge',
+ markrgb=(.8, .8, .8))
+ groundplane_material = self.mjcf_model.asset.add(
+ 'material',
+ name='groundplane',
+ texture=groundplane_texture,
+ texrepeat=(5, 5),
+ texuniform='true',
+ reflectance=0.2)
+
+ # Add ground plane.
+ self.mjcf_model.worldbody.add(
+ 'geom',
+ name='ground',
+ type='plane',
+ material=groundplane_material,
+ size=(1, 1, 0.1),
+ friction=(0.4,),
+ solimp=(0.95, 0.99, 0.001),
+ solref=(0.002, 1))
+
+ # Add lighting.
+ self.mjcf_model.worldbody.add(
+ 'light',
+ pos=(0, 0, 1.5),
+ dir=(0, 0, -1),
+ diffuse=(0.7, 0.7, 0.7),
+ specular=(.3, .3, .3),
+ directional='false',
+ castshadow='true')
+
+ # Always initialize the free camera so that it points at the origin.
+ self.mjcf_model.statistic.center = (0., 0., 0.)
+
+ def attach_offset(self, entity, offset, attach_site=None):
+ """Attaches another entity at a position offset from the attachment site.
+
+ Args:
+ entity: The `Entity` to attach.
+ offset: A length 3 array-like object representing the XYZ offset.
+ attach_site: (optional) The site to which to attach the entity's model.
+ If not set, defaults to self.attachment_site.
+ Returns:
+ The frame of the attached model.
+ """
+ frame = self.attach(entity, attach_site=attach_site)
+ frame.pos = offset
+ return frame
diff --git a/dm_control/manipulation/shared/cameras.py b/dm_control/manipulation/shared/cameras.py
new file mode 100644
index 00000000..729f7e8d
--- /dev/null
+++ b/dm_control/manipulation/shared/cameras.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tools for adding custom cameras to the arena."""
+
+import collections
+
+from dm_control.composer.observation import observable
+
+
+CameraSpec = collections.namedtuple('CameraSpec', ['name', 'pos', 'xyaxes'])
+
+# Custom cameras that may be added to the arena for particular tasks.
+FRONT_CLOSE = CameraSpec(
+ name='front_close',
+ pos=(0., -0.6, 0.75),
+ xyaxes=(1., 0., 0., 0., 0.7, 0.75)
+)
+
+FRONT_FAR = CameraSpec(
+ name='front_far',
+ pos=(0., -0.8, 1.),
+ xyaxes=(1., 0., 0., 0., 0.7, 0.75)
+)
+
+TOP_DOWN = CameraSpec(
+ name='top_down',
+ pos=(0., 0., 2.5),
+ xyaxes=(1., 0., 0., 0., 1., 0.)
+)
+
+LEFT_CLOSE = CameraSpec(
+ name='left_close',
+ pos=(-0.6, 0., 0.75),
+ xyaxes=(0., -1., 0., 0.7, 0., 0.75)
+)
+
+RIGHT_CLOSE = CameraSpec(
+ name='right_close',
+ pos=(0.6, 0., 0.75),
+ xyaxes=(0., 1., 0., -0.7, 0., 0.75)
+)
+
+
+def add_camera_observables(entity, obs_settings, *camera_specs):
+ """Adds cameras to an entity's worldbody and configures observables for them.
+
+ Args:
+ entity: A `composer.Entity`.
+ obs_settings: An `observations.ObservationSettings` instance.
+ *camera_specs: Instances of `CameraSpec`.
+
+ Returns:
+ A `collections.OrderedDict` keyed on camera names, containing pre-configured
+ `observable.MJCFCamera` instances.
+ """
+ obs_dict = collections.OrderedDict()
+ for spec in camera_specs:
+ camera = entity.mjcf_model.worldbody.add('camera', **spec._asdict())
+ obs = observable.MJCFCamera(camera)
+ obs.configure(**obs_settings.camera._asdict())
+ obs_dict[spec.name] = obs
+ return obs_dict
diff --git a/dm_control/manipulation/shared/constants.py b/dm_control/manipulation/shared/constants.py
new file mode 100644
index 00000000..6c42b0eb
--- /dev/null
+++ b/dm_control/manipulation/shared/constants.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Global constants used in manipulation tasks."""
+
+CONTROL_TIMESTEP = 0.04 # Interval between agent actions, in seconds.
+
+# Predefined RGBA values
+RED = (1., 0., 0., 0.3)
+GREEN = (0., 1., 0., 0.3)
+BLUE = (0., 0., 1., 0.3)
+CYAN = (0., 1., 1., 0.3)
+MAGENTA = (1., 0., 1., 0.3)
+YELLOW = (1., 1., 0., 0.3)
+
+TASK_SITE_GROUP = 3 # Invisible group for task-related sites.
diff --git a/dm_control/manipulation/shared/observations.py b/dm_control/manipulation/shared/observations.py
new file mode 100644
index 00000000..b358ad63
--- /dev/null
+++ b/dm_control/manipulation/shared/observations.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Shared configuration options for observations."""
+
+import collections
+import numpy as np
+
+
+class ObservableSpec(collections.namedtuple(
+ 'ObservableSpec',
+ ['enabled', 'update_interval', 'buffer_size', 'delay', 'aggregator',
+ 'corruptor'])):
+ """Configuration options for generic observables."""
+ __slots__ = ()
+
+
+class CameraObservableSpec(collections.namedtuple(
+ 'CameraObservableSpec', ('height', 'width') + ObservableSpec._fields)):
+ """Configuration options for camera observables."""
+ __slots__ = ()
+
+
+class ObservationSettings(collections.namedtuple(
+ 'ObservationSettings', ['proprio', 'ftt', 'prop_pose', 'camera'])):
+ """Container of `ObservableSpecs` grouped by category."""
+ __slots__ = ()
+
+
+class ObservableNames(collections.namedtuple(
+ 'ObservableNames', ['proprio', 'ftt', 'prop_pose', 'camera'])):
+ """Container that groups the names of observables by category."""
+ __slots__ = ()
+
+ def __new__(cls, proprio=(), ftt=(), prop_pose=(), camera=()):
+ return super(ObservableNames, cls).__new__(
+ cls, proprio=proprio, ftt=ftt, prop_pose=prop_pose, camera=camera)
+
+
+# Global defaults for "feature" observables (i.e. anything that isn't a camera).
+_DISABLED_FEATURE = ObservableSpec(
+ enabled=False,
+ update_interval=1,
+ buffer_size=1,
+ delay=0,
+ aggregator=None,
+ corruptor=None)
+_ENABLED_FEATURE = _DISABLED_FEATURE._replace(enabled=True)
+
+# Force, torque and touch-sensor readings are scaled using a symmetric
+# logarithmic transformation that handles 0 and negative values.
+_symlog1p = lambda x, random_state: np.sign(x) * np.log1p(abs(x))
+_DISABLED_FTT = _DISABLED_FEATURE._replace(corruptor=_symlog1p)
+_ENABLED_FTT = _ENABLED_FEATURE._replace(corruptor=_symlog1p)
+
+# Global defaults for camera observables.
+_DISABLED_CAMERA = CameraObservableSpec(
+ height=84,
+ width=84,
+ enabled=False,
+ update_interval=1,
+ buffer_size=1,
+ delay=0,
+ aggregator=None,
+ corruptor=None)
+_ENABLED_CAMERA = _DISABLED_CAMERA._replace(enabled=True)
+
+
+# Predefined sets of configuration options to apply to each category of
+# observable.
+PERFECT_FEATURES = ObservationSettings(
+ proprio=_ENABLED_FEATURE,
+ ftt=_ENABLED_FTT,
+ prop_pose=_ENABLED_FEATURE,
+ camera=_DISABLED_CAMERA)
+
+VISION = ObservationSettings(
+ proprio=_ENABLED_FEATURE,
+ ftt=_ENABLED_FTT,
+ prop_pose=_DISABLED_FEATURE,
+ camera=_ENABLED_CAMERA)
+
+JACO_ARM_OBSERVABLES = ObservableNames(
+ proprio=['joints_pos', 'joints_vel'], ftt=['joints_torque'])
+JACO_HAND_OBSERVABLES = ObservableNames(
+ proprio=['joints_pos', 'joints_vel', 'pinch_site_pos', 'pinch_site_rmat'])
+FREEPROP_OBSERVABLES = ObservableNames(
+ prop_pose=['position', 'orientation', 'linear_velocity',
+ 'angular_velocity'])
+
+
+def make_options(obs_settings, obs_names):
+  """Constructs a dict of configuration options for a set of named observables.
+
+  Args:
+    obs_settings: An `ObservationSettings` instance.
+    obs_names: An `ObservableNames` instance.
+
+  Returns:
+    A nested dict containing `{observable_name: {option_name: value}}`.
+  """
+  observable_options = {}
+  for category, spec in obs_settings._asdict().items():
+    for observable_name in getattr(obs_names, category):
+      observable_options[observable_name] = spec._asdict()
+  return observable_options
diff --git a/dm_control/manipulation/shared/registry.py b/dm_control/manipulation/shared/registry.py
new file mode 100644
index 00000000..0e43ccda
--- /dev/null
+++ b/dm_control/manipulation/shared/registry.py
@@ -0,0 +1,37 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A global registry of constructors for manipulation environments."""
+
+
+from dm_control.utils import containers
+
+_ALL_CONSTRUCTORS = containers.TaggedTasks(allow_overriding_keys=False)
+
+add = _ALL_CONSTRUCTORS.add
+get_constructor = _ALL_CONSTRUCTORS.__getitem__
+get_all_names = _ALL_CONSTRUCTORS.keys
+get_tags = _ALL_CONSTRUCTORS.tags
+get_names_by_tag = _ALL_CONSTRUCTORS.tagged
+
+# This disables the check that prevents the same task constructor name from
+# being added to the container more than once. This is done in order to allow
+# individual task modules to be reloaded without also reloading `registry.py`
+# first (e.g. when "hot-reloading" environments using IPython's `autoreload`
+# extension).
+
+
+def done_importing_tasks():
+  _ALL_CONSTRUCTORS.allow_overriding_keys = True
diff --git a/dm_control/manipulation/shared/robots.py b/dm_control/manipulation/shared/robots.py
new file mode 100644
index 00000000..8ef4d038
--- /dev/null
+++ b/dm_control/manipulation/shared/robots.py
@@ -0,0 +1,53 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Custom robot constructors with manipulation-specific defaults."""
+
+
+from dm_control.entities.manipulators import kinova
+from dm_control.manipulation.shared import observations
+
+
+# The default position of the base of the arm relative to the origin.
+ARM_OFFSET = (0., 0.4, 0.)
+
+
+def make_arm(obs_settings):
+  """Constructs a robot arm with manipulation-specific defaults.
+
+  Args:
+    obs_settings: `observations.ObservationSettings` instance.
+
+  Returns:
+    An instance of `manipulators.base.RobotArm`.
+  """
+  return kinova.JacoArm(
+      observable_options=observations.make_options(
+          obs_settings, observations.JACO_ARM_OBSERVABLES))
+
+
+def make_hand(obs_settings):
+  """Constructs a robot hand with manipulation-specific defaults.
+
+  Args:
+    obs_settings: `observations.ObservationSettings` instance.
+
+  Returns:
+    An instance of `manipulators.base.RobotHand`.
+  """
+  return kinova.JacoHand(
+      use_pinch_site_as_tcp=True,
+      observable_options=observations.make_options(
+          obs_settings, observations.JACO_HAND_OBSERVABLES))
diff --git a/dm_control/manipulation/shared/tags.py b/dm_control/manipulation/shared/tags.py
new file mode 100644
index 00000000..6c9ad069
--- /dev/null
+++ b/dm_control/manipulation/shared/tags.py
@@ -0,0 +1,22 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""String constants used to annotate task constructors."""
+
+
+FEATURES = 'features'
+VISION = 'vision'
+
+EASY = 'easy'
diff --git a/dm_control/manipulation/shared/workspaces.py b/dm_control/manipulation/shared/workspaces.py
new file mode 100644
index 00000000..260f2972
--- /dev/null
+++ b/dm_control/manipulation/shared/workspaces.py
@@ -0,0 +1,87 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tools for defining and visualizing workspaces for manipulation tasks.
+
+Workspaces define distributions from which the initial positions and/or
+orientations of the hand and prop(s) are sampled, plus other task-specific
+spatial parameters such as target sizes.
+"""
+
+import collections
+from dm_control.composer.variation import distributions
+from dm_control.composer.variation import rotations
+from dm_control.entities.manipulators import base
+from dm_control.manipulation.shared import constants
+import numpy as np
+
+
+_MIN_SITE_DIMENSION = 1e-6 # Ensures that all site dimensions are positive.
+_VISIBLE_GROUP = 0
+_INVISIBLE_GROUP = 3  # Invisible sensor sites live in group 3 by convention.
+
+DOWN_QUATERNION = base.DOWN_QUATERNION
+
+BoundingBox = collections.namedtuple('BoundingBox', ['lower', 'upper'])
+
+uniform_z_rotation = rotations.QuaternionFromAxisAngle(
+ axis=(0., 0., 1.),
+ # NB: We must specify `single_sample=True` here otherwise we will sample a
+ # length-4 array of angles rather than a scalar. This happens because
+ # `PropPlacer` passes in the previous quaternion as `initial_value`,
+ # and by default `distributions.Distribution` assumes that the shape
+ # of the output array should be the same as that of `initial_value`.
+ angle=distributions.Uniform(-np.pi, np.pi, single_sample=True))
+
+
+def add_bbox_site(body, lower, upper, visible=False, **kwargs):
+  """Adds a site for visualizing a bounding box to an MJCF model.
+
+  Args:
+    body: An `mjcf.Element`, the (world)body to which the site should be added.
+    lower: A sequence of lower x,y,z bounds.
+    upper: A sequence of upper x,y,z bounds.
+    visible: Whether the site should be visible by default.
+    **kwargs: Keyword arguments used to set other attributes of the newly
+      created site.
+
+  Returns:
+    An `mjcf.Element` representing the newly created site.
+  """
+  upper = np.array(upper)
+  lower = np.array(lower)
+  pos = (upper + lower) / 2.
+  size = np.maximum((upper - lower) / 2., _MIN_SITE_DIMENSION)
+  group = None if visible else constants.TASK_SITE_GROUP
+  return body.add(
+      'site', type='box', pos=pos, size=size, group=group, **kwargs)
+
+
+def add_target_site(body, radius, visible=False, **kwargs):
+  """Adds a site for visualizing a target location.
+
+  Args:
+    body: An `mjcf.Element`, the (world)body to which the site should be added.
+    radius: The radius of the target.
+    visible: Whether the site should be visible by default.
+    **kwargs: Keyword arguments used to set other attributes of the newly
+      created site.
+
+  Returns:
+    An `mjcf.Element` representing the newly created site.
+  """
+  group = None if visible else constants.TASK_SITE_GROUP
+  return body.add(
+      'site', type='sphere', size=[radius], group=group, **kwargs)
diff --git a/dm_control/mjcf/README.md b/dm_control/mjcf/README.md
new file mode 100644
index 00000000..056f4174
--- /dev/null
+++ b/dm_control/mjcf/README.md
@@ -0,0 +1,498 @@
+# PyMJCF
+
+IMPORTANT: If you find yourself stuck while using PyMJCF, check out the various
+IMPORTANT boxes on this page and the [Common gotchas](#common-gotchas) section
+at the bottom to see if any of them is relevant.
+
+This library provides a Python object model for MuJoCo's XML-based
+[MJCF](http://www.mujoco.org/book/modeling.html) physics modeling language. The
+goal of the library is to allow users to easily interact with and modify MJCF
+models in Python, similarly to what the JavaScript DOM does for HTML.
+
+A key feature of this library is the ability to easily compose multiple separate
+MJCF models into a larger one. Disambiguation of duplicated names from different
+models, or multiple instances of the same model, is handled automatically.
+
+The following snippet provides a quick example of this library's typical use
+case. Here, the `UpperBody` class can simply instantiate two copies of `Arm`,
+thus reducing code duplication. The names of bodies, joints, or geoms of each
+`Arm` are automatically prefixed by their parent's names, and so no name
+collision occurs.
+
+```python
+from dm_control import mjcf
+
+class Arm:
+
+  def __init__(self, name):
+    self.mjcf_model = mjcf.RootElement(model=name)
+
+    self.upper_arm = self.mjcf_model.worldbody.add('body', name='upper_arm')
+    self.shoulder = self.upper_arm.add('joint', name='shoulder', type='ball')
+    self.upper_arm.add('geom', name='upper_arm', type='capsule',
+                       pos=[0, 0, -0.15], size=[0.045, 0.15])
+
+    self.forearm = self.upper_arm.add('body', name='forearm', pos=[0, 0, -0.3])
+    self.elbow = self.forearm.add('joint', name='elbow',
+                                  type='hinge', axis=[0, 1, 0])
+    self.forearm.add('geom', name='forearm', type='capsule',
+                     pos=[0, 0, -0.15], size=[0.045, 0.15])
+
+class UpperBody:
+
+  def __init__(self):
+    self.mjcf_model = mjcf.RootElement()
+    self.mjcf_model.worldbody.add(
+        'geom', name='torso', type='box', size=[0.15, 0.045, 0.25])
+    left_shoulder_site = self.mjcf_model.worldbody.add(
+        'site', size=[1e-6]*3, pos=[-0.15, 0, 0.25])
+    right_shoulder_site = self.mjcf_model.worldbody.add(
+        'site', size=[1e-6]*3, pos=[0.15, 0, 0.25])
+
+    self.left_arm = Arm(name='left_arm')
+    left_shoulder_site.attach(self.left_arm.mjcf_model)
+
+    self.right_arm = Arm(name='right_arm')
+    right_shoulder_site.attach(self.right_arm.mjcf_model)
+
+body = UpperBody()
+physics = mjcf.Physics.from_mjcf_model(body.mjcf_model)
+```
+
+## Basic operations
+
+### Creating an MJCF model
+
+In PyMJCF, the basic building block of a model is an `mjcf.Element`. This
+corresponds to an element in the generated XML. However, user code _cannot_
+instantiate a generic `mjcf.Element` object directly.
+
+A valid model always consists of a single root `<mujoco>` element. This is
+represented as the special `mjcf.RootElement` type in PyMJCF, which _can_ be
+instantiated in user code to create an empty model.
+
+```python
+from dm_control import mjcf
+
+mjcf_model = mjcf.RootElement()
+print(mjcf_model)  # MJCF Element: <mujoco/>
+```
+
+### Adding new elements
+
+Attributes of the new element can be passed as kwargs:
+
+```python
+my_box = mjcf_model.worldbody.add('geom', name='my_box',
+                                  type='box', pos=[0, .1, 0])
+print(my_box)  # MJCF Element: <geom name="my_box" type="box" pos="0. 0.1 0."/>
+```
+
+### Parsing an existing XML document
+
+Alternatively, if an XML file already exists, PyMJCF can parse it to
+create a Python object:
+
+```python
+from dm_control import mjcf
+
+# Parse from path
+mjcf_model = mjcf.from_path(filename)
+
+# Parse from file
+with open(filename) as f:
+  mjcf_model = mjcf.from_file(f)
+
+# Parse from string
+with open(filename) as f:
+  xml_string = f.read()
+mjcf_model = mjcf.from_xml_string(xml_string)
+
+print(type(mjcf_model))  # <class 'dm_control.mjcf.element.RootElement'>
+```
+
+### Traversing through a model
+
+Consider the following MJCF model:
+
+```xml
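+<mujoco>
+  <default>
+    <default class="brick">
+      <geom rgba="1 0 0 1"/>
+    </default>
+  </default>
+  <worldbody>
+    <body name="foo">
+      <freejoint/>
+      <inertial pos="0 0 0" mass="1"/>
+      <body name="bar">
+        <joint name="my_hinge" type="hinge"/>
+        <geom name="my_geom" pos="0 1 2" class="brick"/>
+      </body>
+    </body>
+  </worldbody>
+</mujoco>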
+```
+
+The child elements and XML attributes of an `Element` object are exposed as
+Python attributes. These attributes all have the same names as their XML
+counterparts, with one exception: the `class` XML attribute is named `dclass` in
+order to avoid a clash with the Python `class` keyword:
+
+```python
+my_geom = mjcf_model.worldbody.body['foo'].body['bar'].geom['my_geom']
+print(isinstance(mjcf_model, mjcf.Element)) # True
+print(my_geom.name) # 'my_geom'
+print(my_geom.pos) # np.array([0., 1., 2.], dtype=float)
+# print(my_geom.class)  # SyntaxError: `class` is a reserved Python keyword
+print(my_geom.dclass) # 'brick'
+```
+
+Note that attribute values in the object model are **not** affected by defaults:
+
+```python
+print(mjcf_model.default.default['brick'].geom.rgba) # [1, 0, 0, 1]
+print(my_geom.rgba) # None
+```
+
+### Finding elements without traversing
+
+We can also find elements directly without having to traverse through the object
+hierarchy:
+
+```python
+found_geom = mjcf_model.find('geom', 'my_geom')
+print(found_geom == my_geom) # True
+```
+
+Find all elements of a given type:
+
+```python
+# Note that <freejoint> is also considered a joint
+joints = mjcf_model.find_all('joint')
+print(len(joints)) # 2
+print(joints[0] == mjcf_model.worldbody.body['foo'].freejoint) # True
+print(joints[1] == mjcf_model.worldbody.body['foo'].body['bar'].joint[0]) # True
+```
+
+Note that the order of elements returned by `find_all` is the same as the order
+in which they are declared in the model.
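
The declaration-order guarantee mirrors a plain document-order traversal of the XML. As a rough illustration that is independent of PyMJCF, the sketch below uses Python's standard `xml.etree.ElementTree` on a small model written for this example (its element names are assumptions modelled on the snippets above) and collects `<freejoint>` and `<joint>` elements in the order they are declared.

```python
import xml.etree.ElementTree as ET

# A small stand-in model, assumed for this example only.
XML = """
<mujoco>
  <worldbody>
    <body name="foo">
      <freejoint/>
      <body name="bar">
        <joint name="my_hinge" type="hinge"/>
        <geom name="my_geom" pos="0 1 2"/>
      </body>
    </body>
  </worldbody>
</mujoco>
"""

root = ET.fromstring(XML)
# `Element.iter()` walks the tree in document (declaration) order.
joint_tags = ('joint', 'freejoint')
joints = [el for el in root.iter() if el.tag in joint_tags]
print([el.tag for el in joints])
```

This only demonstrates the ordering idea on raw XML; in PyMJCF itself the equivalent query is the `find_all('joint')` call shown above.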
+
+### Modifying XML attributes
+
+Attributes can be modified, added, or removed:
+
+```python
+my_geom.pos = [1, 2, 3]
+print(my_geom.pos) # np.array([1., 2., 3.], dtype=float)
+my_geom.quat = [0, 1, 0, 0]
+print(my_geom.quat) # np.array([0., 1., 0., 0.], dtype=float)
+del my_geom.quat
+print(my_geom.quat) # None
+```
+
+Schema violations result in errors:
+
+```python
+print(my_geom.poss)  # raises AttributeError (no child or attribute called poss)
+my_geom.pos = 'invalid'  # raises ValueError (assigning string to array)
+my_geom.pos = [1, 2, 3, 4, 5, 6]  # raises ValueError (array is too long)
+
+# raises ValueError (mass is a required attribute of <inertial>)
+del mjcf_model.find('body', 'foo').inertial.mass
+```
+
+### Uniqueness of identifiers
+
+PyMJCF enforces the uniqueness of "identifier" attributes within a model.
+Identifiers consist of the `class` attribute of a `<default>`, and all `name`
+attributes. Their uniqueness is only enforced within a particular namespace. For
+example, a `<body>` is allowed to have the same name as a `<geom>`, whereas
+`<position>` and `<velocity>` actuators cannot have the same name.
+
+```python
+mjcf_model.worldbody.add('geom', name='my_geom')
+foo = mjcf_model.worldbody.find('body', 'foo')
+foo.add('geom', name='my_geom')  # Error, duplicated geom name
+foo.add('geom', name='foo')  # OK, a geom can have the same name as a body
+mjcf_model.find('geom', 'foo').name = 'my_geom' # Error, duplicated geom name
+```
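
The namespace rule can be pictured with a small stand-alone sketch. It is a toy model written for illustration, not PyMJCF's implementation: the `NameRegistry` class and the tag-to-namespace table are assumptions, chosen so that bodies and geoms get separate namespaces while two actuator types (here position and velocity) share one.

```python
# Toy illustration of per-namespace name uniqueness (hypothetical helper,
# not part of PyMJCF).
NAMESPACE_OF_TAG = {
    'body': 'body',
    'geom': 'geom',
    'position': 'actuator',  # All actuator types share a single namespace.
    'velocity': 'actuator',
}


class NameRegistry:
  """Tracks which names are already taken in each namespace."""

  def __init__(self):
    self._taken = set()

  def register(self, tag, name):
    key = (NAMESPACE_OF_TAG[tag], name)
    if key in self._taken:
      raise ValueError(
          'Duplicated name {!r} in namespace {!r}'.format(name, key[0]))
    self._taken.add(key)


registry = NameRegistry()
registry.register('body', 'foo')
registry.register('geom', 'foo')  # OK: bodies and geoms are separate namespaces.
registry.register('position', 'act')
try:
  registry.register('velocity', 'act')  # Clashes within the actuator namespace.
except ValueError as e:
  print(e)
```

Keying the set on `(namespace, name)` rather than on `name` alone is what lets a body and a geom share a name while still rejecting the second actuator.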
+
+### Reference attributes
+
+Some attributes are references to other elements. For example, the `joint`
+attribute of an actuator refers to a `<joint>` element in the model.
+
+An `mjcf.Element` can be directly assigned to these reference attributes:
+
+```python
+my_hinge = mjcf_model.find('joint', 'my_hinge')
+my_actuator = mjcf_model.actuator.add('velocity', joint=my_hinge)
+```
+
+This is the recommended way to assign reference attributes, since it guarantees
+that the reference is not invalidated if the referenced element is renamed.
+Alternatively, a string can also be assigned to reference attributes. In this
+case, PyMJCF does **not** attempt to verify that the named element actually
+exists in the model.
+
+IMPORTANT: If the element being referenced is in a different model to the
+reference attribute (e.g. in an attached model), the reference **must** be
+created by directly assigning an `mjcf.Element` object to the attribute rather
+than a string. Strings assigned to reference attributes cannot contain '/',
+since they are automatically scoped by PyMJCF upon attachment.
+
+## Attaching models
+
+In this section we will refer to an `mjcf.RootElement` simply as a "model".
+Models can be _attached_ to other models in order to create compositional
+scenes.
+
+```python
+arena = mjcf.RootElement()
+arena.worldbody.add('geom', name='ground', type='plane', size=[10, 10, 1])
+
+robot = mjcf.from_path('robot.xml')
+arena.attach(robot)
+```
+
+We refer to `arena` as the _parent model_, and `robot` as the _child model_ (or
+the _attached model_).
+
+### Attachment frames
+
+When a model is attached to a site, an empty body is created in the parent
+model. This empty body is called an _attachment frame_.
+
+The attachment frame is created as a child of the body that contains the
+attachment site, and it has the same position and orientation as the site. When
+the XML is generated, the attachment frame's contents shadow the contents of the
+attached model's `<worldbody>`. The attachment frame's name in the generated XML
+is the child's `fully/qualified/prefix/`. The trailing slash ensures that the
+attachment frame's name never collides with a user-defined body.
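
The naming convention can be sketched in a few lines of plain Python. This is a toy illustration rather than PyMJCF code: `scoped_name` and `attachment_frame_name` are hypothetical helpers and the model names are made up. It assumes that user-supplied names never contain '/', which is why a name ending in a slash cannot collide with a user-defined body.

```python
def scoped_name(prefixes, name):
  """Joins the chain of attached-model names and an element name with '/'."""
  return '/'.join(list(prefixes) + [name])


def attachment_frame_name(prefixes):
  """Returns the fully qualified prefix, including the trailing slash."""
  return '/'.join(prefixes) + '/'


# A geom called 'bar' in a model named 'child' attached to a top-level parent.
print(scoped_name(['child'], 'bar'))
print(attachment_frame_name(['child']))

# Nested attachment: 'hand' is attached to 'arm', which is attached to the
# top-level model.
print(scoped_name(['arm', 'hand'], 'finger'))
```

An element name such as `child/bar` always has a non-empty part after the last slash, whereas the frame name `child/` does not, so the two can never be equal.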
+
+More concretely, if we have the following parent and child models:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+Then the final generated XML will be:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+IMPORTANT: The attachment frame is created _transparently_ to the user. In
+particular, it is NOT treated as a regular `body` by PyMJCF. Its name in the
+generated XML should be considered an implementation detail and should NOT be
+relied on.
+
+Having said that, it is sometimes necessary to access the attachment frame, for
+example to add a joint between the parent and the child model. The easiest way
+to do this is to hold a reference to the object returned by a call to `attach`:
+
+```python
+attachment_frame = parent_model.attach(child_model)
+attachment_frame.add('freejoint')
+```
+
+Alternatively, if a model has already been attached, the `find` function can be
+used with the `attachment_frame` namespace in order to retrieve the attachment
+frame. The `get_attachment_frame` convenience function in `mjcf.traversal_utils`
+can find the child model's attachment frame without needing access to the parent
+model.
+
+```python
+frame_1 = parent_model.find('attachment_frame', 'child')
+
+# Convenience function: get the attachment frame directly from a child model
+frame_2 = mjcf.traversal_utils.get_attachment_frame(child_model)
+print(frame_1 == frame_2) # True
+```
+
+IMPORTANT: To encourage good modeling practices, the only allowed direct
+children of an attachment frame are `<joint>` and `<inertial>`. Other types of
+elements should instead be added to the `<worldbody>` of the attached model.
+
+### Element ownership
+
+IMPORTANT: Elements of child models do **not** appear when traversing through
+the parent model.
+
+### Default classes
+
+PyMJCF ensures that default classes of a parent model _never_ affect any of its
+child models. This minimises the possibility that two models become subtly
+"incompatible", as a model always behaves in the same way regardless of what it
+is attached to.
+
+The way that PyMJCF achieves this in practice is to move everything in a model's
+global `<default>` context into a default class named `/`. In other words, a
+PyMJCF-generated model never has anything in the global default context.
+Instead, the generated model always looks like:
+
+```xml
+
+
+
+
+
+
+
+
+
+```
+
+IMPORTANT: This transformation is _transparent_ to the user. Within Python, the
+above geom rgba setting is accessed as if it were a global default, i.e.
+`mjcf_model.default.geom.rgba`. Generally speaking, users should never have to
+worry about PyMJCF's internal handling of defaults.
+
+When a model is attached, its `/` default class turns into
+`fully/qualified/prefix/`. The trailing slash ensures that this transformation
+never conflicts with a user-named default class. More specifically, if we have
+the following parent and child models:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+Then the final generated XML will be:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+### Global options
+
+A model cannot be attached to another model if _any_ of the global options are
+different. Global options consist of attributes of ``, `"""
- with mock.patch.object(core, "logging") as mock_logging:
- core.MjModel.from_xml_string(xml_with_warning)
- mock_logging.warn.assert_called_once_with(
- "Error: Pre-allocated constraint buffer is full. "
- "Increase njmax above 2. Time = 0.0000.")
+
+ # This model should compile successfully, but raise a warning on the first
+ # simulation step.
+ model = core.MjModel.from_xml_string(xml_with_warning)
+ data = core.MjData(model)
+ mujoco.mj_step(model.ptr, data.ptr)
def testLoadXMLWithAssetsFromString(self):
core.MjModel.from_xml_string(MODEL_WITH_ASSETS, assets=ASSETS)
- with self.assertRaises(core.Error):
+ with self.assertRaises(ValueError):
# Should fail to load without the assets
core.MjModel.from_xml_string(MODEL_WITH_ASSETS)
- def testVFSFilenameTooLong(self):
- limit = core._MAX_VFS_FILENAME_CHARACTERS
- contents = "fake contents"
- valid_filename = "a" * limit
- with core._temporary_vfs({valid_filename: contents}):
- pass
- invalid_filename = "a" * (limit + 1)
- expected_message = core._VFS_FILENAME_TOO_LONG.format(
- length=(limit + 1), limit=limit, filename=invalid_filename)
- with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
- with core._temporary_vfs({invalid_filename: contents}):
- pass
-
def testSaveLastParsedModelToXML(self):
save_xml_path = os.path.join(OUT_DIR, "tmp_humanoid.xml")
@@ -170,16 +134,16 @@ def testDimensions(self):
def testStep(self):
t0 = self.data.time
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
- self.assert_(np.all(np.isfinite(self.data.qpos[:])))
- self.assert_(np.all(np.isfinite(self.data.qvel[:])))
+ self.assertTrue(np.all(np.isfinite(self.data.qpos[:])))
+ self.assertTrue(np.all(np.isfinite(self.data.qvel[:])))
def testMultipleData(self):
data2 = core.MjData(self.model)
self.assertNotEqual(self.data.ptr, data2.ptr)
t0 = self.data.time
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
self.assertEqual(data2.time, 0)
@@ -194,7 +158,7 @@ def testModelName(self):
@parameterized.named_parameters(
("_copy", lambda x: x.copy()),
- ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ ("_pickle_unpickle", lambda x: pickle.loads(pickle.dumps(x))),)
def testCopyOrPickleModel(self, func):
timestep = 0.12345
self.model.opt.timestep = timestep
@@ -207,40 +171,40 @@ def testCopyOrPickleModel(self, func):
@parameterized.named_parameters(
("_copy", lambda x: x.copy()),
- ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ ("_pickle_unpickle", lambda x: pickle.loads(pickle.dumps(x))),)
def testCopyOrPickleData(self, func):
- for _ in xrange(10):
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ for _ in range(10):
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
data2 = func(self.data)
attr_to_compare = ("time", "energy", "qpos", "xpos")
self.assertNotEqual(data2.ptr, self.data.ptr)
self._assert_attributes_equal(data2, self.data, attr_to_compare)
- for _ in xrange(10):
- mjlib.mj_step(self.model.ptr, self.data.ptr)
- mjlib.mj_step(data2.model.ptr, data2.ptr)
+ for _ in range(10):
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(data2.model.ptr, data2.ptr)
self._assert_attributes_equal(data2, self.data, attr_to_compare)
@parameterized.named_parameters(
("_copy", lambda x: x.copy()),
- ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ ("_pickle_unpickle", lambda x: pickle.loads(pickle.dumps(x))),)
def testCopyOrPickleStructs(self, func):
- for _ in xrange(10):
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ for _ in range(10):
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
data2 = func(self.data)
self.assertNotEqual(data2.ptr, self.data.ptr)
- for name in ["warning", "timer", "solver"]:
- self._assert_structs_equal(getattr(self.data, name), getattr(data2, name))
- for _ in xrange(10):
- mjlib.mj_step(self.model.ptr, self.data.ptr)
- mjlib.mj_step(data2.model.ptr, data2.ptr)
- for expected, actual in zip(self.data.timer, data2.timer):
- self._assert_structs_equal(expected, actual)
+ attr_to_compare = ("warning", "solver")
+ self._assert_attributes_equal(self.data, data2, attr_to_compare)
+ for _ in range(10):
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(data2.model.ptr, data2.ptr)
+ self._assert_attributes_equal(self.data, data2, attr_to_compare)
@parameterized.parameters(
("right_foot", "body", 6),
- ("right_foot", enums.mjtObj.mjOBJ_BODY, 6),
+ ("right_foot", mujoco.mjtObj.mjOBJ_BODY, 6),
("left_knee", "joint", 11),
- ("left_knee", enums.mjtObj.mjOBJ_JOINT, 11))
+ ("left_knee", mujoco.mjtObj.mjOBJ_JOINT, 11),
+ )
def testNamesIds(self, name, object_type, object_id):
output_id = self.model.name2id(name, object_type)
self.assertEqual(object_id, output_id)
@@ -248,9 +212,9 @@ def testNamesIds(self, name, object_type, object_id):
self.assertEqual(name, output_name)
def testNamesIdsExceptions(self):
- with self.assertRaisesRegexp(core.Error, "does not exist"):
+ with self.assertRaisesRegex(core.Error, "does not exist"):
self.model.name2id("nonexistent_body_name", "body")
- with self.assertRaisesRegexp(core.Error, "is not a valid object type"):
+ with self.assertRaisesRegex(core.Error, "is not a valid object type"):
self.model.name2id("right_foot", "nonexistent_type_name")
def testNamelessObject(self):
@@ -258,20 +222,6 @@ def testNamelessObject(self):
name = self.model.id2name(0, "camera")
self.assertEqual("", name)
- def testWarningCallback(self):
- self.data.qpos[0] = np.inf
- with mock.patch.object(core, "logging") as mock_logging:
- mjlib.mj_step(self.model.ptr, self.data.ptr)
- mock_logging.warn.assert_called_once_with(
- "Nan, Inf or huge value in QPOS at DOF 0. The simulation is unstable. "
- "Time = 0.0000.")
-
- def testErrorCallback(self):
- with mock.patch.object(core, "logging") as mock_logging:
- mjlib.mj_activate(b"nonexistent_activation_key")
- mock_logging.fatal.assert_called_once_with(
- "Could not open activation key file nonexistent_activation_key")
-
def testSingleCallbackContext(self):
callback_was_called = [False]
@@ -279,7 +229,7 @@ def testSingleCallbackContext(self):
def callback(unused_model, unused_data):
callback_was_called[0] = True
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertFalse(callback_was_called[0])
class DummyError(RuntimeError):
@@ -289,7 +239,7 @@ class DummyError(RuntimeError):
with core.callback_context("mjcb_passive", callback):
# Stepping invokes the `mjcb_passive` callback.
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertTrue(callback_was_called[0])
# Exceptions should not prevent `mjcb_passive` from being reset.
@@ -300,7 +250,7 @@ class DummyError(RuntimeError):
# `mjcb_passive` should have been reset to None.
callback_was_called[0] = False
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertFalse(callback_was_called[0])
def testNestedCallbackContexts(self):
@@ -318,24 +268,24 @@ def inner(unused_model, unused_data):
with core.callback_context("mjcb_passive", outer):
# This should execute `outer` a few times.
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertEqual(last_called[0], outer_called)
with core.callback_context("mjcb_passive", inner):
# This should execute `inner` a few times.
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertEqual(last_called[0], inner_called)
# When we exit the inner context, the `mjcb_passive` callback should be
# reset to `outer`.
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertEqual(last_called[0], outer_called)
# When we exit the outer context, the `mjcb_passive` callback should be
# reset to None, and stepping should not affect `last_called`.
last_called[0] = None
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
self.assertIsNone(last_called[0])
def testDisableFlags(self):
@@ -357,8 +307,8 @@ def testDisableFlags(self):
"""
model = core.MjModel.from_xml_string(xml_string)
data = core.MjData(model)
- for _ in xrange(100): # Let the simulation settle for a while.
- mjlib.mj_step(model.ptr, data.ptr)
+ for _ in range(100): # Let the simulation settle for a while.
+ mujoco.mj_step(model.ptr, data.ptr)
# With gravity and contact enabled, the cube should be stationary and the
# touch sensor should give a reading of ~9.81 N.
@@ -368,56 +318,181 @@ def testDisableFlags(self):
# If we disable both contacts and gravity then the cube should remain
# stationary and the touch sensor should read zero.
with model.disable("contact", "gravity"):
- mjlib.mj_step(model.ptr, data.ptr)
+ mujoco.mj_step(model.ptr, data.ptr)
self.assertAlmostEqual(data.qvel[0], 0, places=4)
self.assertEqual(data.sensordata[0], 0)
# If we disable contacts but not gravity then the cube should fall through
# the floor.
- with model.disable(enums.mjtDisableBit.mjDSBL_CONTACT):
- for _ in xrange(10):
- mjlib.mj_step(model.ptr, data.ptr)
+ with model.disable(mujoco.mjtDisableBit.mjDSBL_CONTACT):
+ for _ in range(10):
+ mujoco.mj_step(model.ptr, data.ptr)
self.assertLess(data.qvel[0], -0.1)
def testDisableFlagsExceptions(self):
- with self.assertRaisesRegexp(ValueError, "not a valid flag name"):
+ with self.assertRaises(ValueError):
with self.model.disable("invalid_flag_name"):
pass
- with self.assertRaisesRegexp(ValueError,
- "not a value in `enums.mjtDisableBit`"):
+ with self.assertRaises(ValueError):
with self.model.disable(-99):
pass
- @parameterized.named_parameters(
- ("MjModel",
- lambda _: core.MjModel.from_xml_path(HUMANOID_XML_PATH),
- "mj_deleteModel"),
- ("MjData",
- lambda self: core.MjData(self.model),
- "mj_deleteData"),
- ("MjvScene",
- lambda _: core.MjvScene(),
- "mjv_freeScene"))
- def testFree(self, constructor, destructor_name):
- for _ in xrange(5):
- destructor = getattr(mjlib, destructor_name)
- with mock.patch.object(
- core.mjlib, destructor_name, wraps=destructor) as mock_destructor:
- wrapper = constructor(self)
-
- expected_address = ctypes.addressof(wrapper.ptr.contents)
- wrapper.free()
- self.assertIsNone(wrapper.ptr)
-
- mock_destructor.assert_called_once()
- pointer = mock_destructor.call_args[0][0]
- actual_address = ctypes.addressof(pointer.contents)
- self.assertEqual(expected_address, actual_address)
+ @parameterized.parameters(
+ # The tip is .5 meters from the cart so we expect its horizontal velocity
+ # to be 1m/s + .5m*1rad/s = 1.5m/s.
+ dict(
+ qpos=[0., 0.], # Pole pointing upwards.
+ qvel=[1., 1.],
+ expected_linvel=[1.5, 0., 0.],
+ expected_angvel=[0., 1., 0.],
+ ),
+ # For the same velocities but with the pole pointing down, we expect the
+ # velocities to cancel, making the global tip velocity now equal to
+ # 1m/s - 0.5m*1rad/s = 0.5m/s.
+ dict(
+ qpos=[0., np.pi], # Pole pointing downwards.
+ qvel=[1., 1.],
+ expected_linvel=[0.5, 0., 0.],
+ expected_angvel=[0., 1., 0.],
+ ),
+ # In the site's local frame, which is now flipped w.r.t the world, the
+ # velocity is in the negative x direction.
+ dict(
+ qpos=[0., np.pi], # Pole pointing downwards.
+ qvel=[1., 1.],
+ expected_linvel=[-0.5, 0., 0.],
+ expected_angvel=[0., 1., 0.],
+ local=True,
+ ),
+ )
+ def testObjectVelocity(
+ self, qpos, qvel, expected_linvel, expected_angvel, local=False):
+ cartpole = """
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ model = core.MjModel.from_xml_string(cartpole)
+ data = core.MjData(model)
+ data.qpos[:] = qpos
+ data.qvel[:] = qvel
+ mujoco.mj_step1(model.ptr, data.ptr)
+ linvel, angvel = data.object_velocity("mass", "geom", local_frame=local)
+ np.testing.assert_array_almost_equal(linvel, expected_linvel)
+ np.testing.assert_array_almost_equal(angvel, expected_angvel)
+
+ def testContactForce(self):
+ box_on_floor = """
+
+
+
+
+
+
+
+
+
+ """
+ model = core.MjModel.from_xml_string(box_on_floor)
+ data = core.MjData(model)
+ # Settle for 500 timesteps (1 second):
+ for _ in range(500):
+ mujoco.mj_step(model.ptr, data.ptr)
+ normal_force = 0.
+ for contact_id in range(data.ncon):
+ force = data.contact_force(contact_id)
+ normal_force += force[0, 0]
+ box_id = 1
+ box_weight = -model.opt.gravity[2]*model.body_mass[box_id]
+ self.assertAlmostEqual(normal_force, box_weight)
+ # Test raising of out-of-range errors:
+ bad_ids = [-1, data.ncon]
+ for bad_id in bad_ids:
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ core._CONTACT_ID_OUT_OF_RANGE.format(
+ max_valid=data.ncon - 1, actual=bad_id)):
+ data.contact_force(bad_id)
+ @parameterized.parameters(
+ dict(
+ condim=3, # Only sliding friction.
+ expected_torques=[False, False, False], # No torques.
+ ),
+ dict(
+ condim=4, # Sliding and torsional friction.
+ expected_torques=[True, False, False], # Only torsional torque.
+ ),
+ dict(
+ condim=6, # Sliding, torsional and rolling.
+ expected_torques=[True, True, True], # All torques are nonzero.
+ ),
+ )
+ def testContactTorque(self, condim, expected_torques):
+ ball_on_floor = """
+
+
+
+
+
+
+
+
+
+ """
+ model = core.MjModel.from_xml_string(ball_on_floor)
+ data = core.MjData(model)
+ model.geom_condim[:] = condim
+ data.qvel[3:] = np.array((1., 1., 1.))
+ # Settle for 10 timesteps (20 milliseconds):
+ for _ in range(10):
+ mujoco.mj_step(model.ptr, data.ptr)
+ contact_id = 0 # This model has only one contact.
+ _, torque = data.contact_force(contact_id)
+ nonzero_torques = torque != 0
+ np.testing.assert_array_equal(nonzero_torques, np.array(expected_torques))
+
+ def testFreeMjrContext(self):
+ for _ in range(5):
+ renderer = _render.Renderer(640, 480)
+ mjr_context = core.MjrContext(self.model, renderer)
# Explicit freeing should not break any automatic GC triggered later.
- del wrapper
+ del mjr_context
+ renderer.free()
+ del renderer
gc.collect()
+ def testSceneGeomsAttribute(self):
+ scene = core.MjvScene(model=self.model)
+ self.assertEqual(scene.ngeom, 0)
+ self.assertEmpty(scene.geoms)
+ geom_types = (
+ mujoco.mjtObj.mjOBJ_BODY,
+ mujoco.mjtObj.mjOBJ_GEOM,
+ mujoco.mjtObj.mjOBJ_SITE,
+ )
+ for geom_type in geom_types:
+ scene.ngeom += 1
+ scene.geoms[scene.ngeom - 1].objtype = geom_type
+ self.assertLen(scene.geoms, len(geom_types))
+ self.assertEqual(tuple(g.objtype for g in scene.geoms), geom_types)
+
+ def testInvalidFontScale(self):
+ invalid_font_scale = 99
+ with self.assertRaises(ValueError):
+ core.MjrContext(model=self.model,
+ gl_context=None, # Don't need a context for this test.
+ font_scale=invalid_font_scale)
+
def _get_attributes_test_params():
model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
@@ -428,7 +503,8 @@ def _get_attributes_test_params():
array_args = []
scalar_args = []
skipped_args = []
- for parent_name, parent_obj in zip(("model", "data"), (model, data)):
+ for parent_name, parent_obj in zip(("model", "data"),
+ (model._model, data._data)):
for attr_name in dir(parent_obj):
if not attr_name.startswith("_"): # Skip 'private' attributes
args = (parent_name, attr_name)
@@ -459,15 +535,26 @@ def testReadWriteArray(self, parent_name, attr_name):
raise TypeError("{}.{} has incorrect type {!r} - must be one of {!r}."
.format(parent_name, attr_name, type(attr), ARRAY_TYPES))
# Check that we can read the contents of the array
- old_contents = attr[:]
- # Don't write to integer arrays since these might contain pointers.
- if not np.issubdtype(old_contents.dtype, int):
- # Write unique values to the array, check that we can read them back.
- new_contents = np.arange(old_contents.size, dtype=old_contents.dtype)
- new_contents.shape = old_contents.shape
- attr[:] = new_contents
- np.testing.assert_array_equal(new_contents, attr[:])
- self._take_steps() # Take a few steps, check that we don't get segfaults.
+ _ = attr[:]
+
+ # Write unique values into the array and read them back.
+ self._write_unique_values(attr_name, attr)
+ self._take_steps() # Take a few steps, check that we don't get segfaults.
+
+ def _write_unique_values(self, attr_name, target_array):
+ # If the target array is structured, recursively write unique values into
+ # each subfield.
+ if target_array.dtype.fields is not None:
+ for field_name in target_array.dtype.fields:
+ self._write_unique_values(attr_name, target_array[field_name])
+ # Don't write to integer arrays since these might contain pointers. Also
+ # don't write directly into the stack.
+ elif (attr_name != "stack"
+ and not np.issubdtype(target_array.dtype, np.integer)):
+ new_contents = np.arange(target_array.size, dtype=target_array.dtype)
+ new_contents.shape = target_array.shape
+ target_array[:] = new_contents
+ np.testing.assert_array_equal(new_contents, target_array[:])
@parameterized.parameters(*_scalar_args)
def testReadWriteScalar(self, parent_name, attr_name):
@@ -492,12 +579,13 @@ def testSkipped(self, *unused_args):
pass
def setUp(self):
+ super().setUp()
self.model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
self.data = core.MjData(self.model)
def _take_steps(self, n=5):
- for _ in xrange(n):
- mjlib.mj_step(self.model.ptr, self.data.ptr)
+ for _ in range(n):
+ mujoco.mj_step(self.model.ptr, self.data.ptr)
if __name__ == "__main__":
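Note on the `testObjectVelocity` cases in the test diff above: the expected values follow from rigid-body kinematics, v_tip = v_cart + ω × r. The sketch below re-derives the three parameter sets in plain Python, without MuJoCo. It assumes the hinge rotates about the world y axis and that the tip sits 0.5 m from the cart, as the test comments and `expected_angvel` suggest. The helper names `cross` and `tip_velocity` are illustrative only and are not part of `dm_control`.

```python
import math


def cross(a, b):
  """Cross product of two 3-vectors given as tuples."""
  return (a[1] * b[2] - a[2] * b[1],
          a[2] * b[0] - a[0] * b[2],
          a[0] * b[1] - a[1] * b[0])


def tip_velocity(pole_angle, cart_speed, pole_rate, pole_length=0.5,
                 local=False):
  """Returns (linvel, angvel) of the pole tip for a cart sliding along x.

  Assumption: the pole hinges about the world y axis, and pole_angle == 0
  means the pole points straight up (+z).
  """
  c, s = math.cos(pole_angle), math.sin(pole_angle)
  # Vector from the hinge to the tip, rotated about y by pole_angle.
  r = (pole_length * s, 0.0, pole_length * c)
  omega = (0.0, pole_rate, 0.0)
  swing = cross(omega, r)
  linvel = (cart_speed + swing[0], swing[1], swing[2])
  angvel = omega
  if local:
    # Express the world-frame velocity in the tip's frame (multiply by R^T).
    linvel = (c * linvel[0] - s * linvel[2],
              linvel[1],
              s * linvel[0] + c * linvel[2])
    # angvel lies along the rotation axis, so it is unchanged.
  return linvel, angvel


if __name__ == "__main__":
  for angle, local in ((0.0, False), (math.pi, False), (math.pi, True)):
    print(angle, local, tip_velocity(angle, 1.0, 1.0, local=local))
```

With the pole down, the swing term flips sign and partly cancels the cart velocity. In the local frame, which is rotated by π about y, the same vector points along negative x.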
diff --git a/dm_control/mujoco/wrapper/mjbindings/__init__.py b/dm_control/mujoco/wrapper/mjbindings/__init__.py
index 8e248e17..57621b20 100644
--- a/dm_control/mujoco/wrapper/mjbindings/__init__.py
+++ b/dm_control/mujoco/wrapper/mjbindings/__init__.py
@@ -15,22 +15,21 @@
"""Import core names of MuJoCo ctypes bindings."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
from absl import logging
from dm_control.mujoco.wrapper.mjbindings import constants
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import sizes
-from dm_control.mujoco.wrapper.mjbindings import types
-from dm_control.mujoco.wrapper.mjbindings import wrappers
+
+# Internal analytics import.
# pylint: disable=g-import-not-at-top
try:
from dm_control.mujoco.wrapper.mjbindings import functions
from dm_control.mujoco.wrapper.mjbindings.functions import mjlib
+ logging.info('MuJoCo library version is: %s', mjlib.mj_versionString())
+ # Internal analytics.
except (IOError, OSError):
- logging.warn('mjbindings failed to import mjlib and other functions. '
- 'libmujoco.so may not be accessible.')
+ logging.warning('mjbindings failed to import mjlib and other functions. '
+ 'libmujoco.so may not be accessible.')
diff --git a/dm_control/mujoco/wrapper/mjbindings/functions.py b/dm_control/mujoco/wrapper/mjbindings/functions.py
new file mode 100755
index 00000000..8cd1c3ba
--- /dev/null
+++ b/dm_control/mujoco/wrapper/mjbindings/functions.py
@@ -0,0 +1,31 @@
+# Copyright 2022 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Aliases for the mujoco library, provided for backwards compatibility.
+
+New code should import mujoco directly, instead of accessing these constants or
+mjlib through this module.
+"""
+import mujoco
+
+mjlib = mujoco
+
+mjDISABLESTRING = mujoco.mjDISABLESTRING
+mjENABLESTRING = mujoco.mjENABLESTRING
+mjTIMERSTRING = mujoco.mjTIMERSTRING
+mjLABELSTRING = mujoco.mjLABELSTRING
+mjFRAMESTRING = mujoco.mjFRAMESTRING
+mjVISSTRING = mujoco.mjVISSTRING
+mjRNDSTRING = mujoco.mjRNDSTRING
diff --git a/dm_control/mujoco/wrapper/mjbindings_test.py b/dm_control/mujoco/wrapper/mjbindings_test.py
index cd959a50..9dae5803 100644
--- a/dm_control/mujoco/wrapper/mjbindings_test.py
+++ b/dm_control/mujoco/wrapper/mjbindings_test.py
@@ -15,15 +15,8 @@
"""Tests for mjbindings."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
from absl.testing import absltest
from absl.testing import parameterized
-
from dm_control.mujoco.wrapper.mjbindings import constants
from dm_control.mujoco.wrapper.mjbindings import sizes
@@ -35,7 +28,6 @@ class MjbindingsTest(parameterized.TestCase):
('mjmodel', 'geom_type', ('ngeom',)),
# Fields with identifiers in mjxmacro that are resolved at compile-time.
('mjmodel', 'actuator_dynprm', ('nu', constants.mjNDYN)),
- ('mjdata', 'efc_solref', ('njmax', constants.mjNREF)),
# Fields with multiple named indices.
('mjmodel', 'key_qpos', ('nkey', 'nq')),
)
diff --git a/dm_control/mujoco/wrapper/mujoco_timer.lds b/dm_control/mujoco/wrapper/mujoco_timer.lds
new file mode 100644
index 00000000..f64aa856
--- /dev/null
+++ b/dm_control/mujoco/wrapper/mujoco_timer.lds
@@ -0,0 +1,6 @@
+{
+ global:
+ dm_control_mujoco_get_time;
+ local:
+ *;
+};
diff --git a/dm_control/mujoco/wrapper/util.py b/dm_control/mujoco/wrapper/util.py
index 1a85abde..c03e69cc 100644
--- a/dm_control/mujoco/wrapper/util.py
+++ b/dm_control/mujoco/wrapper/util.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The dm_control Authors.
+# Copyright 2017-2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,206 +15,35 @@
"""Various helper functions and classes."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import ctypes
-import ctypes.util
import functools
-import os
-import platform
import sys
-import threading
-# Internal dependencies.
+import mujoco
import numpy as np
-import six
-
-from dm_control.utils import resources
-# Environment variables that can be used to override the default paths to the
-# MuJoCo shared library and key file.
+# Environment variable that can be used to override the default path to the
+# MuJoCo shared library.
ENV_MJLIB_PATH = "MJLIB_PATH"
-ENV_MJKEY_PATH = "MJKEY_PATH"
-
-
-MJLIB_NAME = "mujoco150"
-
-
-def _get_shared_library_filename():
- try:
- libc_path = ctypes.util.find_library("c")
- libc_filename = os.path.split(libc_path)[1]
- prefix = "lib" if libc_filename.startswith("lib") else ""
- extension = libc_filename.split(".")[1]
- except (AttributeError, IndexError):
- prefix = "lib"
- extension = "so"
- return "{}{}.{}".format(prefix, MJLIB_NAME, extension)
-
-
-DEFAULT_MJLIB_PATH = os.path.join(
- "~/.mujoco/mjpro150/bin", _get_shared_library_filename())
-DEFAULT_MJKEY_PATH = "~/.mujoco/mjkey.txt"
-
DEFAULT_ENCODING = sys.getdefaultencoding()
def to_binary_string(s):
"""Convert text string to binary."""
- if isinstance(s, six.binary_type):
+ if isinstance(s, bytes):
return s
return s.encode(DEFAULT_ENCODING)
def to_native_string(s):
"""Convert a text or binary string to the native string format."""
- if six.PY3 and isinstance(s, six.binary_type):
+ if isinstance(s, bytes):
return s.decode(DEFAULT_ENCODING)
- elif six.PY2 and isinstance(s, six.text_type):
- return s.encode(DEFAULT_ENCODING)
else:
return s
-def _get_full_path(path):
- expanded_path = os.path.expanduser(os.path.expandvars(path))
- return resources.GetResourceFilename(expanded_path)
-
-
def get_mjlib():
- """Loads `libmujoco.so` and returns it as a `ctypes.CDLL` object."""
- try:
- # Use the MJLIB_PATH environment variable if it has been set.
- raw_path = os.environ[ENV_MJLIB_PATH]
- except KeyError:
- paths_to_try = [
- # If libmujoco is in LD_LIBRARY_PATH then ctypes only needs its name.
- os.path.basename(DEFAULT_MJLIB_PATH),
- _get_full_path(DEFAULT_MJLIB_PATH),
- ]
- for library_path in paths_to_try:
- try:
- return ctypes.cdll.LoadLibrary(library_path)
- except OSError as e:
- if "undefined symbol" in str(e) and platform.system() == "Linux":
- # This means that we've found MuJoCo but haven't loaded GLEW.
- ctypes.CDLL(ctypes.util.find_library("GL"), ctypes.RTLD_GLOBAL)
- ctypes.CDLL(ctypes.util.find_library("GLEW"), ctypes.RTLD_GLOBAL)
- return ctypes.cdll.LoadLibrary(library_path)
- raw_path = DEFAULT_MJLIB_PATH
- return ctypes.cdll.LoadLibrary(_get_full_path(raw_path))
-
-
-def get_mjkey_path():
- """Returns a path to the MuJoCo key file."""
- raw_path = os.environ.get(ENV_MJKEY_PATH, DEFAULT_MJKEY_PATH)
- return _get_full_path(raw_path)
-
-
-class WrapperBase(object):
- """Base class for wrappers that provide getters/setters for ctypes structs."""
-
- # This is needed so that the __del__ methods of MjModel and MjData can still
- # succeed in cases where an exception occurs during __init__() before the _ptr
- # attribute has been assigned.
- _ptr = None
-
- def __init__(self, ptr, model=None):
- """Constructs a wrapper instance from a `ctypes.Structure`.
-
- Args:
- ptr: `ctypes.POINTER` to the struct to be wrapped.
- model: `MjModel` instance; needed by `MjDataWrapper` in order to get the
- dimensions of dynamically-sized arrays at runtime.
- """
- self._ptr = ptr
- self._model = model
-
- @property
- def ptr(self):
- """Pointer to the underlying `ctypes.Structure` instance."""
- return self._ptr
-
-
-class CachedProperty(property):
- """A property that is evaluated only once per object instance."""
-
- def __init__(self, func, doc=None):
- super(CachedProperty, self).__init__(fget=func, doc=doc)
- self.lock = threading.RLock()
-
- def __get__(self, obj, cls):
- if obj is None:
- return self
- name = self.fget.__name__
- obj_dict = obj.__dict__
- with self.lock:
- try:
- # Return cached result if it was computed before the lock was acquired
- return obj_dict[name]
- except KeyError:
- # Otherwise call the function, cache the result, and return it
- return obj_dict.setdefault(name, self.fget(obj))
-
-
-# It's easy to create numpy arrays from a pointer then have these persist after
-# the model has been destroyed and its underlying memory freed. To mitigate the
-# risk of writing to a pointer after it has been freed, all array attributes are
-# read-only by default. In order to write to them you need to explicitly set
-# their ".writeable" flag to True (the SetFlags context manager above provides
-# a convenient way to do this).
-
-# The proper solution would be to prevent the model from being garbage-collected
-# whilst any of the views onto its buffers are still alive.
-
-
-def _as_array(src, shape):
- """Converts a native `src` array to a managed numpy buffer.
-
- Args:
- src: A ctypes pointer or array.
- shape: A tuple specifying the dimensions of the output array.
-
- Returns:
- A numpy array.
- """
-
- # To work around a memory leak in numpy, we have to go through this
- # frombuffer method instead of calling ctypeslib.as_array. See
- # https://github.com/numpy/numpy/issues/6511
- # return np.ctypeslib.as_array(src, shape)
-
- # This is part of the public API. See
- # http://git.net/ml/python.ctypes/2008-02/msg00014.html
- ctype = src._type_ # pylint: disable=protected-access
-
- size = np.product(shape)
- ptr = ctypes.cast(src, ctypes.POINTER(ctype * size))
- buf = np.frombuffer(ptr.contents, dtype=ctype)
- buf.shape = shape
- return buf
-
-
-def buf_to_npy(src, shape, np_dtype=None):
- """Returns a numpy array view of the contents of a ctypes pointer or array.
-
- Args:
- src: A ctypes pointer or array.
- shape: A tuple specifying the dimensions of the output array.
- np_dtype: A string or `np.dtype` object specifying the dtype of the output
- array. If None, the dtype is inferred from the type of `src`.
-
- Returns:
- A numpy array.
- """
- # This causes a harmless RuntimeWarning about mismatching buffer format
- # strings due to a bug in ctypes: http://stackoverflow.com/q/4964101/1461210
- arr = _as_array(src, shape)
- if np_dtype is not None:
- arr.dtype = np_dtype
- return arr
+ return mujoco
@functools.wraps(np.ctypeslib.ndpointer)
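Note on the string helpers in the `util.py` diff above: with the `six` branches removed, both functions reduce to a plain `bytes` check. The standalone sketch below copies the post-change bodies from the diff so the round trip can be tried without installing `dm_control`. The sample values are illustrative.

```python
import sys

DEFAULT_ENCODING = sys.getdefaultencoding()


def to_binary_string(s):
  """Convert text string to binary."""
  if isinstance(s, bytes):
    return s
  return s.encode(DEFAULT_ENCODING)


def to_native_string(s):
  """Convert a text or binary string to the native string format."""
  if isinstance(s, bytes):
    return s.decode(DEFAULT_ENCODING)
  else:
    return s


if __name__ == "__main__":
  # Both helpers are idempotent: values already in the target type pass through.
  print(to_binary_string("torso"), to_binary_string(b"torso"))
  print(to_native_string(b"torso"), to_native_string("torso"))
```

Because Python 3 has a single native string type, the old `six.PY2` branch that encoded text to bytes has no equivalent here.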
diff --git a/dm_control/mujoco/wrapper/util_test.py b/dm_control/mujoco/wrapper/util_test.py
deleted file mode 100644
index 1d6cf982..00000000
--- a/dm_control/mujoco/wrapper/util_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Tests for util."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import resource
-
-# Internal dependencies.
-
-from absl.testing import absltest
-
-from dm_control.mujoco.wrapper import core
-from dm_control.mujoco.wrapper import util
-
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-_NUM_CALLS = 10000
-_RSS_GROWTH_TOLERANCE = 150 # Bytes
-
-
-class UtilTest(absltest.TestCase):
-
- def test_buf_to_npy_no_memory_leak(self):
- """Ensures we can call buf_to_npy without leaking memory."""
- model = core.MjModel.from_xml_string("")
- src = model._ptr.contents.name_geomadr
- shape = (model.ngeom,)
-
- # This uses high water marks to find memory leaks in native code.
- old_max = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
- for _ in xrange(_NUM_CALLS):
- buf = util.buf_to_npy(src, shape)
- del buf
- new_max = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
- growth = new_max - old_max
-
- if growth > _RSS_GROWTH_TOLERANCE:
- self.fail("RSS grew by {} bytes, exceeding tolerance of {} bytes."
- .format(growth, _RSS_GROWTH_TOLERANCE))
-
-if __name__ == "__main__":
- absltest.main()
diff --git a/dm_control/render/__init__.py b/dm_control/render/__init__.py
deleted file mode 100644
index 1c1a0953..00000000
--- a/dm_control/render/__init__.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""OpenGL context management for rendering MuJoCo scenes.
-
-The `Renderer` class will use one of the following rendering APIs, in order of
-descending priority: EGL > GLFW > OSMesa.
-
-Rendering support can be disabled globally by setting the
-`DISABLE_MUJOCO_RENDERING` environment variable before launching the Python
-interpreter. This allows the MuJoCo bindings in `dm_control.mujoco` to be used
-on platforms where an OpenGL context cannot be created. Attempting to render
-when rendering has been disabled will result in a `RuntimeError`.
-"""
-
-import os
-DISABLED = bool(os.environ.get('DISABLE_MUJOCO_RENDERING', ''))
-del os
-
-DISABLED_MESSAGE = (
- 'Rendering support has been disabled by the `DISABLE_MUJOCO_RENDERING` '
- 'environment variable')
-
-# pylint: disable=g-import-not-at-top
-# pylint: disable=invalid-name
-
-_GLFWRenderer = None
-_EGLRenderer = None
-_OSMesaRenderer = None
-
-if not DISABLED:
- try:
- from dm_control.render.glfw_renderer import GLFWContext as _GLFWRenderer
- except ImportError:
- pass
- try:
- from dm_control.render.egl_renderer import EGLContext as _EGLRenderer
- except ImportError:
- pass
- try:
- from dm_control.render.osmesa_renderer import OSMesaContext as _OSMesaRenderer
- except ImportError:
- pass
-
- if _EGLRenderer:
- Renderer = _EGLRenderer
- elif _GLFWRenderer:
- Renderer = _GLFWRenderer
- elif _OSMesaRenderer:
- Renderer = _OSMesaRenderer
- else:
- raise ImportError(
- 'No OpenGL rendering backend could be imported. To use '
- '`dm_control.mujoco` without rendering support, set the '
- '`DISABLE_MUJOCO_RENDERING` environment variable before launching your '
- 'interpreter.')
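Note on the removed `render/__init__.py` above: the flag is computed as `bool(os.environ.get(...))`, so any non-empty value disables rendering, including `0` or `false`. The sketch below shows that behaviour together with the documented EGL > GLFW > OSMesa priority. The functions `rendering_disabled` and `pick_backend` are hypothetical helpers written for this note; the removed module uses module-level code instead.

```python
def rendering_disabled(environ):
  """Mirrors bool(os.environ.get('DISABLE_MUJOCO_RENDERING', ''))."""
  return bool(environ.get('DISABLE_MUJOCO_RENDERING', ''))


def pick_backend(available):
  """Returns the highest-priority backend name found in `available`.

  Hypothetical helper: `available` is a collection of backend names whose
  imports succeeded.
  """
  for name in ('EGL', 'GLFW', 'OSMesa'):
    if name in available:
      return name
  raise ImportError('No OpenGL rendering backend could be imported.')


if __name__ == "__main__":
  # Any non-empty string counts as "disabled", even '0'.
  for value in ('', '0', 'false', '1'):
    print(repr(value), rendering_disabled({'DISABLE_MUJOCO_RENDERING': value}))
  print(pick_backend({'OSMesa', 'GLFW'}))
```

To re-enable rendering under this scheme, the variable has to be unset or set to an empty string, not set to `0`.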
diff --git a/dm_control/render/base.py b/dm_control/render/base.py
deleted file mode 100644
index db2c1fd1..00000000
--- a/dm_control/render/base.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Base class for OpenGL context handlers.
-
-The module lays foundation for defining various rendering contexts in a uniform
-manner.
-
-ContextBase defines a common interface rendering contexts should fulfill. In
-addition, it provides a context activation method that can be used in 'with'
-statements to ensure symmetrical context activation and deactivation.
-
-The problem of optimizing context swaps falls to ContextPolicyManager and the
-accompanying policy classes. OptimizedContextPolicy will attempt to reduce
-the number of context swaps, increasing application's performance.
-DebugContextPolicy, on the other, hand will rigorously keep activating and
-deactivating contexts for each request, providing a reliable framework for
-functional tests of the new context implementations.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import contextlib
-import threading
-
-# Internal dependencies.
-import six
-
-_ACTIVE_CONTEXT_PARAM = '_active_context'
-
-# A storage for thread local data.
-_thread_local_data = threading.local()
-
-
-class _ContextPolicyManager(object):
- """Manages a context switching policy."""
-
- def __init__(self):
- """Instance initializer."""
- self._policy = None
- self.enable_debug_mode(False)
-
- def enable_debug_mode(self, flag):
- """Enables/disables a debug context management policy.
-
- For details, please see DebugContextPolicy docstring.
-
- Args:
- flag: A boolean value.
- """
- if flag:
- self._policy = _DebugContextPolicy()
- else:
- self._policy = _OptimizedContextPolicy()
-
- def activate(self, context, width, height):
- """Forwards the call to policy method that handles context activation.
-
- Args:
- context: Render context to activate, an instance of ContextBase.
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
- """
- self._policy.activate(context, width, height)
-
- def deactivate(self, context):
- """Forwards the call to policy method that handles context deactivation.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- self._policy.deactivate(context)
-
- def release_context(self, context):
- """Forwards the call to policy method that handles context tracking.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- self._policy.release_context(context)
-
-
-class _OptimizedContextPolicy(object):
- """Context management policy that performs lazy context activation.
-
- It performs context activations only when the context or the viewport size
- change. If an application uses only a single context with a fixed-size
- viewport, the policy will have it activated only once.
-
- Moreover, the policy makes sure that each context is activated and then used
- from the same thread of execution.
- """
-
- def __init__(self):
- """Instance initializer."""
- self._context_stamp = (0, -1, -1)
-
- def activate(self, context, width, height):
- """Performs a lazy context activation.
-
- Checks if the context has changed since the last call, and if it has, it
- proceeds with the activation procedure.
- Activation consists of deactivating the previously active context, if any,
- and then activating the new context.
-
- Args:
- context: Render context to activate, an instance of ContextBase.
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
- """
- context_stamp = (id(context), width, height)
- if self._context_stamp == context_stamp:
- return
- else:
- if self._active_context:
- self._active_context.deactivate()
- self._active_context = context
- self._context_stamp = context_stamp
- if context:
- context.activate(width, height)
-
- def deactivate(self, context):
- """Performs a lazy context deactivation.
-
- Actual deactivation is deferred to the activation procedure.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- pass
-
- def release_context(self, context):
- """Stops tracking the specified context, releasing references to it.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- if self._active_context is context:
- self._active_context = None
-
- @property
- def _active_context(self):
- value = getattr(_thread_local_data, _ACTIVE_CONTEXT_PARAM, None)
- return value
-
- @_active_context.setter
- def _active_context(self, value):
- setattr(_thread_local_data, _ACTIVE_CONTEXT_PARAM, value)
-
-
-class _DebugContextPolicy(object):
- """Context management policy used for debugging rendering problems.
-
- It always activates and then symmetrically deactivates the rendering context,
- for every 'make_current' call made.
- """
-
- def activate(self, context, width, height):
- """Activates the specified context.
-
- Args:
- context: Render context to activate, an instance of ContextBase.
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
- """
- context.activate(width, height)
-
- def deactivate(self, context):
- """Deactivates the specified context.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- context.deactivate()
-
- def release_context(self, context):
- """The call is ignored by this policy.
-
- Args:
- context: Render context to deactivate, an instance of ContextBase.
- """
- pass
-
-
-# A singleton instance of the context policy manager.
-policy_manager = _ContextPolicyManager()
-
-
-@six.add_metaclass(abc.ABCMeta)
-class ContextBase(object):
- """Base class for managing OpenGL contexts."""
-
- def __init__(self):
- """Initializes this context."""
-
- @abc.abstractmethod
- def activate(self, width, height):
- """Called when entering the `make_current` context manager.
-
- Args:
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
- """
-
- @abc.abstractmethod
- def deactivate(self):
- """Called when exiting the `make_current` context manager."""
-
- @abc.abstractmethod
- def _free(self):
- """Performs an implementation specific context cleanup."""
-
- def free(self):
- """Frees resources associated with this context."""
- policy_manager.release_context(self)
- self._free()
-
- def __del__(self):
- self.free()
-
- @contextlib.contextmanager
- def make_current(self, width, height):
- """Context manager that makes this Renderer's OpenGL context current.
-
- Args:
- width: Integer specifying the new framebuffer width in pixels.
- height: Integer specifying the new framebuffer height in pixels.
-
- Yields:
- None
- """
- policy_manager.activate(self, width, height)
- try:
- yield
- finally:
- policy_manager.deactivate(self)
-
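Note on the removed `_OptimizedContextPolicy` above: it re-activates a context only when the (context, width, height) stamp changes, and it defers deactivation until another context takes over. The sketch below is a simplified, self-contained reconstruction of that idea for illustration. `CountingContext` and `LazyContextPolicy` are hypothetical names, and the sketch tests the previous context with `is not None` where the original relies on truthiness.

```python
import threading

# Each thread tracks its own active context, as in the removed module.
_thread_local = threading.local()


class CountingContext:
  """Stand-in for a render context that counts calls."""

  def __init__(self):
    self.activations = 0
    self.deactivations = 0

  def activate(self, width, height):
    self.activations += 1

  def deactivate(self):
    self.deactivations += 1


class LazyContextPolicy:
  """Activates only when the context or the viewport size changes."""

  def __init__(self):
    self._stamp = (0, -1, -1)

  def activate(self, context, width, height):
    stamp = (id(context), width, height)
    if stamp == self._stamp:
      return  # Same context and size: nothing to do.
    previous = getattr(_thread_local, 'active', None)
    if previous is not None:
      previous.deactivate()
    _thread_local.active = context
    self._stamp = stamp
    if context is not None:
      context.activate(width, height)

  def deactivate(self, context):
    # Deferred: the real deactivation happens on the next activation.
    pass


if __name__ == "__main__":
  policy = LazyContextPolicy()
  ctx = CountingContext()
  for _ in range(3):
    policy.activate(ctx, 1024, 768)
    policy.deactivate(ctx)
  print(ctx.activations, ctx.deactivations)
```

This matches what `test_activating_same_context_multiple_times` in the removed `base_test.py` checks: repeated activation of one context at a fixed size results in a single `activate` call and no `deactivate` calls.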
diff --git a/dm_control/render/base_test.py b/dm_control/render/base_test.py
deleted file mode 100644
index 342d7309..00000000
--- a/dm_control/render/base_test.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Tests for the base rendering module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import threading
-# Internal dependencies.
-from absl.testing import absltest
-from dm_control.render import base
-import mock
-import six
-
-WIDTH = 1024
-HEIGHT = 768
-
-
-class ContextBaseTests(absltest.TestCase):
-
- class ContextMock(base.ContextBase):
-
- def __init__(self):
- super(ContextBaseTests.ContextMock, self).__init__()
-
- def activate(self, width, height):
- pass
-
- def deactivate(self):
- pass
-
- def _free(self):
- pass
-
- def setUp(self):
- self.original_manager = base.policy_manager
-
- base.policy_manager = mock.MagicMock()
- self.context = ContextBaseTests.ContextMock()
- self.context._policy = base.policy_manager
-
- def tearDown(self):
- base.policy_manager = self.original_manager
-
- def test_activating_context(self):
- with self.context.make_current(WIDTH, HEIGHT):
- base.policy_manager.activate.assert_called_once_with(
- self.context, WIDTH, HEIGHT)
- base.policy_manager.deactivate.assert_called_once_with(self.context)
-
-
-class ContextPolicyManagerTests(absltest.TestCase):
-
- def setUp(self):
- self.context = mock.MagicMock()
- self.policy = mock.MagicMock()
- base.policy_manager._policy = self.policy
-
- def test_activation(self):
- base.policy_manager.activate(self.context, WIDTH, HEIGHT)
- self.policy.activate.assert_called_once_with(self.context, WIDTH, HEIGHT)
-
- def test_deactivation(self):
- base.policy_manager.deactivate(self.context)
- self.policy.deactivate.assert_called_once_with(self.context)
-
- def test_selecting_policy(self):
- base.policy_manager.enable_debug_mode(True)
- self.assertIsInstance(
- base.policy_manager._policy, base._DebugContextPolicy)
- base.policy_manager.enable_debug_mode(False)
- self.assertIsInstance(
- base.policy_manager._policy, base._OptimizedContextPolicy)
-
-
-class OptimizedContextPolicyTests(absltest.TestCase):
-
- def setUp(self):
- self.policy = base._OptimizedContextPolicy()
-
- def test_activating_same_context_multiple_times(self):
- context = mock.MagicMock(spec=base.ContextBase)
- for _ in six.moves.xrange(3):
- self.policy.activate(context, WIDTH, HEIGHT)
- self.policy.deactivate(context)
- context.activate.assert_called_once_with(WIDTH, HEIGHT)
- self.assertEqual(0, context.deactivate.call_count)
-
- def test_switching_contexts(self):
- contexts = [mock.MagicMock(spec=base.ContextBase)
- for _ in six.moves.xrange(3)]
- for context in contexts:
- self.policy.activate(context, WIDTH, HEIGHT)
- self.policy.deactivate(context)
- self.policy.activate(None, WIDTH, HEIGHT)
- for context in contexts:
- context.activate.assert_called_once_with(WIDTH, HEIGHT)
- context.deactivate.assert_called_once()
-
- def test_context_are_tracked_separately_for_each_thread(self):
- parent_context = mock.MagicMock(spec=base.ContextBase)
- child_context = mock.MagicMock(spec=base.ContextBase)
-
- def run():
- # Record the context that was active on this thread prior to activation
- # call.
- self.child_thread_context_before = self.policy._active_context
-
- # Activate and record the activated context.
- self.policy.activate(child_context, WIDTH, HEIGHT)
- self.child_thread_context_after = self.policy._active_context
-
- thread = threading.Thread(target=run)
-
- # Main thread activates 'parent_context'
- self.policy.activate(parent_context, WIDTH, HEIGHT)
- self.assertEqual(parent_context, self.policy._active_context)
-
- # The child thread activates 'child_context'
- thread.start()
- thread.join()
-
- # Activation from separate threads shouldn't affect one another
- self.assertIsNone(self.child_thread_context_before)
- self.assertEqual(parent_context, self.policy._active_context)
- self.assertEqual(child_context, self.child_thread_context_after)
-
-
-class DebugContextPolicyTests(absltest.TestCase):
-
- def setUp(self):
- self.policy = base._DebugContextPolicy()
-
- def test_activating_same_context_multiple_times(self):
- context = mock.MagicMock()
- for _ in six.moves.xrange(3):
- self.policy.activate(context, WIDTH, HEIGHT)
- self.policy.deactivate(context)
- self.assertEqual(3, context.activate.call_count)
- self.assertEqual(3, context.deactivate.call_count)
-
- def test_switching_contexts(self):
- contexts = [mock.MagicMock() for _ in six.moves.xrange(3)]
- for context in contexts:
- self.policy.activate(context, WIDTH, HEIGHT)
- self.policy.deactivate(context)
- for context in contexts:
- context.activate.assert_called_once_with(WIDTH, HEIGHT)
- context.deactivate.assert_called_once()
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/dm_control/render/glfw_renderer_test.py b/dm_control/render/glfw_renderer_test.py
deleted file mode 100644
index 6db3b54b..00000000
--- a/dm_control/render/glfw_renderer_test.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Tests for GLFWContext."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import unittest
-
-# Internal dependencies.
-from absl.testing import absltest
-from dm_control import render
-import mock
-
-MAX_WIDTH = 1024
-MAX_HEIGHT = 1024
-CONTEXT_PATH = render.__name__ + '.glfw_renderer.glfw'
-
-
-@unittest.skipUnless(render._GLFWRenderer,
- reason='GLFW renderer could not be imported.')
-@mock.patch(CONTEXT_PATH)
-class GLFWContextTest(absltest.TestCase):
-
- def setUp(self):
- self.context = mock.MagicMock()
-
- with mock.patch(CONTEXT_PATH):
- self.renderer = render.Renderer(MAX_WIDTH, MAX_HEIGHT)
-
- def tearDown(self):
- self.renderer._context = None
-
- def test_activation(self, mock_glfw):
- self.renderer.activate(MAX_WIDTH, MAX_HEIGHT)
- mock_glfw.make_context_current.assert_called_once()
-
- def test_deactivation(self, mock_glfw):
- self.renderer.deactivate()
- mock_glfw.make_context_current.assert_called_once()
-
- def test_freeing(self, mock_glfw):
- self.renderer._context = mock.MagicMock()
- self.renderer._previous_context = mock.MagicMock()
- self.renderer.free()
- mock_glfw.destroy_window.assert_called_once()
- self.assertIsNone(self.renderer._context)
- self.assertIsNone(self.renderer._previous_context)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/dm_control/rl/control.py b/dm_control/rl/control.py
index 2666ed83..cab00332 100644
--- a/dm_control/rl/control.py
+++ b/dm_control/rl/control.py
@@ -13,29 +13,19 @@
# limitations under the License.
# ============================================================================
-"""An environment.Base subclass for control-specific environments."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""A dm_env.Environment subclass for control-specific environments."""
import abc
import collections
import contextlib
-
-# Internal dependencies.
-
+import dm_env
+from dm_env import specs
import numpy as np
-import six
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-from dm_control.rl import environment
-from dm_control.rl import specs
FLAT_OBSERVATION_KEY = 'observations'
-class Environment(environment.Base):
+class Environment(dm_env.Environment):
"""Class for physics-based reinforcement learning environments."""
def __init__(self,
@@ -44,7 +34,8 @@ def __init__(self,
time_limit=float('inf'),
control_timestep=None,
n_sub_steps=None,
- flat_observation=False):
+ flat_observation=False,
+ legacy_step: bool = True):
"""Initializes a new `Environment`.
Args:
@@ -58,13 +49,16 @@ def __init__(self,
`control_timestep` is not specified.
flat_observation: If True, observations will be flattened and concatenated
into a single numpy array.
+ legacy_step: If True, steps the state with up-to-date position and
+ velocity dependent fields. See Page 6 of
+ https://arxiv.org/abs/2006.12983 for more information.
Raises:
ValueError: If both `n_sub_steps` and `control_timestep` are supplied.
"""
self._task = task
self._physics = physics
- self._time_limit = time_limit
+ self._physics.legacy_step = legacy_step
self._flat_observation = flat_observation
if n_sub_steps is not None and control_timestep is not None:
@@ -77,11 +71,18 @@ def __init__(self,
else:
self._n_sub_steps = 1
+ if time_limit == float('inf'):
+ self._step_limit = float('inf')
+ else:
+ self._step_limit = time_limit / (
+ self._physics.timestep() * self._n_sub_steps)
+ self._step_count = 0
self._reset_next_step = True
def reset(self):
"""Starts a new episode and returns the first `TimeStep`."""
self._reset_next_step = False
+ self._step_count = 0
with self._physics.reset_context():
self._task.initialize_episode(self._physics)
@@ -89,8 +90,8 @@ def reset(self):
if self._flat_observation:
observation = flatten_observation(observation)
- return environment.TimeStep(
- step_type=environment.StepType.FIRST,
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST,
reward=None,
discount=None,
observation=observation)
@@ -102,8 +103,7 @@ def step(self, action):
return self.reset()
self._task.before_step(action, self._physics)
- for _ in xrange(self._n_sub_steps):
- self._physics.step()
+ self._physics.step(self._n_sub_steps)
self._task.after_step(self._physics)
reward = self._task.get_reward(self._physics)
@@ -111,18 +111,20 @@ def step(self, action):
if self._flat_observation:
observation = flatten_observation(observation)
- if self.physics.time() >= self._time_limit:
+ self._step_count += 1
+ if self._step_count >= self._step_limit:
discount = 1.0
else:
discount = self._task.get_termination(self._physics)
- if discount is None:
- return environment.TimeStep(
- environment.StepType.MID, reward, 1.0, observation)
- else:
+ episode_over = discount is not None
+
+ if episode_over:
self._reset_next_step = True
- return environment.TimeStep(
- environment.StepType.LAST, reward, discount, observation)
+ return dm_env.TimeStep(
+ dm_env.StepType.LAST, reward, discount, observation)
+ else:
+ return dm_env.TimeStep(dm_env.StepType.MID, reward, 1.0, observation)
def action_spec(self):
"""Returns the action specification for this environment."""
@@ -194,17 +196,18 @@ def compute_n_steps(control_timestep, physics_timestep, tolerance=1e-8):
def _spec_from_observation(observation):
result = collections.OrderedDict()
- for key, value in six.iteritems(observation):
- result[key] = specs.ArraySpec(value.shape, value.dtype)
+ for key, value in observation.items():
+ result[key] = specs.Array(value.shape, value.dtype, name=key)
return result
# Base class definitions for objects supplied to Environment.
-@six.add_metaclass(abc.ABCMeta)
-class Physics(object):
+class Physics(metaclass=abc.ABCMeta):
"""Simulates a physical environment."""
+ legacy_step: bool = True
+
@abc.abstractmethod
def step(self, n_sub_steps=1):
"""Updates the simulation state.
@@ -242,7 +245,10 @@ def reset_context(self):
Yields:
The `Physics` instance.
"""
- self.reset()
+ try:
+ self.reset()
+ except PhysicsError:
+ pass
yield self
self.after_reset()
@@ -265,8 +271,7 @@ class PhysicsError(RuntimeError):
"""Raised if the state of the physics simulation becomes divergent."""
-@six.add_metaclass(abc.ABCMeta)
-class Task(object):
+class Task(metaclass=abc.ABCMeta):
"""Defines a task in a `control.Environment`."""
@abc.abstractmethod
@@ -330,7 +335,7 @@ def step_spec(self, physics):
that describe the shapes, dtypes and elementwise lower and upper bounds
for the array(s) returned by `self.step`.
"""
- raise NotImplementedError
+ raise NotImplementedError()
@abc.abstractmethod
def get_observation(self, physics):
@@ -379,16 +384,16 @@ def flatten_observation(observation, output_key=FLAT_OBSERVATION_KEY):
and concatenated observation array.
Raises:
- ValueError: If `observation` is not a `collections.MutableMapping`.
+ ValueError: If `observation` is not a `collections.abc.MutableMapping`.
"""
- if not isinstance(observation, collections.MutableMapping):
+ if not isinstance(observation, collections.abc.MutableMapping):
raise ValueError('Can only flatten dict-like observations.')
if isinstance(observation, collections.OrderedDict):
- keys = six.iterkeys(observation)
+ keys = observation.keys()
else:
# Keep a consistent ordering for other mappings.
- keys = sorted(six.iterkeys(observation))
+ keys = sorted(observation.keys())
observation_arrays = [observation[key].ravel() for key in keys]
return type(observation)([(output_key, np.concatenate(observation_arrays))])
diff --git a/dm_control/rl/control_test.py b/dm_control/rl/control_test.py
index f7def19f..9e109a04 100644
--- a/dm_control/rl/control_test.py
+++ b/dm_control/rl/control_test.py
@@ -15,33 +15,25 @@
"""Control Environment tests."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
from absl.testing import absltest
from absl.testing import parameterized
-
from dm_control.rl import control
-
+from dm_env import specs
import mock
import numpy as np
-from dm_control.rl import specs
-
_CONSTANT_REWARD_VALUE = 1.0
_CONSTANT_OBSERVATION = {'observations': np.asarray(_CONSTANT_REWARD_VALUE)}
-_ACTION_SPEC = specs.BoundedArraySpec(
- shape=(1,), dtype=np.float, minimum=0.0, maximum=1.0)
-_OBSERVATION_SPEC = {'observations': specs.ArraySpec(shape=(), dtype=np.float)}
+_ACTION_SPEC = specs.BoundedArray(
+ shape=(1,), dtype=float, minimum=0.0, maximum=1.0)
+_OBSERVATION_SPEC = {'observations': specs.Array(shape=(), dtype=float)}
class EnvironmentTest(parameterized.TestCase):
def setUp(self):
+ super().setUp()
self._task = mock.Mock(spec=control.Task)
self._task.initialize_episode = mock.Mock()
self._task.get_observation = mock.Mock(return_value=_CONSTANT_OBSERVATION)
@@ -72,14 +64,27 @@ def test_environment_calls(self):
self._task.after_step.assert_called_with(self._physics)
self._task.get_termination.assert_called_with(self._physics)
- self.assertEquals(_CONSTANT_REWARD_VALUE, time_step.reward)
+ self.assertEqual(_CONSTANT_REWARD_VALUE, time_step.reward)
- def test_timeout(self):
- self._physics.time = mock.Mock(return_value=2.)
+ @parameterized.parameters(
+ {'physics_timestep': .01, 'control_timestep': None,
+ 'expected_steps': 1000},
+ {'physics_timestep': .01, 'control_timestep': .05,
+ 'expected_steps': 5000})
+ def test_timeout(self, expected_steps, physics_timestep, control_timestep):
+ self._physics.timestep.return_value = physics_timestep
+ time_limit = expected_steps * (control_timestep or physics_timestep)
env = control.Environment(
- physics=self._physics, task=self._task, time_limit=1.)
- env.reset()
- time_step = env.step([1])
+ physics=self._physics, task=self._task, time_limit=time_limit,
+ control_timestep=control_timestep)
+
+ time_step = env.reset()
+ steps = 0
+ while not time_step.last():
+ time_step = env.step([1])
+ steps += 1
+
+ self.assertEqual(steps, expected_steps)
self.assertTrue(time_step.last())
time_step = env.step([1])
@@ -102,12 +107,12 @@ def test_control_timestep(self):
def test_flatten_observations(self):
multimodal_obs = dict(_CONSTANT_OBSERVATION)
- multimodal_obs['sensor'] = np.zeros(7, dtype=np.bool)
+ multimodal_obs['sensor'] = np.zeros(7, dtype=bool)
self._task.get_observation = mock.Mock(return_value=multimodal_obs)
env = control.Environment(
physics=self._physics, task=self._task, flat_observation=True)
timestep = env.reset()
- self.assertEqual(len(timestep.observation), 1)
+ self.assertLen(timestep.observation, 1)
self.assertEqual(timestep.observation[control.FLAT_OBSERVATION_KEY].size,
1 + 7)
@@ -118,7 +123,7 @@ class ComputeNStepsTest(parameterized.TestCase):
(0.03, 0.005, 6))
def testComputeNSteps(self, control_timestep, physics_timestep, expected):
steps = control.compute_n_steps(control_timestep, physics_timestep)
- self.assertEquals(expected, steps)
+ self.assertEqual(expected, steps)
@parameterized.parameters((3, 2), (.003, .00101))
def testComputeNStepsFailures(self, control_timestep, physics_timestep):
diff --git a/dm_control/rl/environment.py b/dm_control/rl/environment.py
deleted file mode 100644
index d6e2f7f9..00000000
--- a/dm_control/rl/environment.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Python RL Environment API."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import collections
-
-# Internal dependencies.
-
-import enum
-import six
-
-
-class TimeStep(collections.namedtuple(
- 'TimeStep', ['step_type', 'reward', 'discount', 'observation'])):
- """Returned with every call to `step` and `reset` on an environment.
-
- A `TimeStep` contains the data emitted by an environment at each step of
- interaction. A `TimeStep` holds a `step_type`, an `observation` (typically a
- NumPy array or a dict or list of arrays), and an associated `reward` and
- `discount`.
-
- The first `TimeStep` in a sequence will have `StepType.FIRST`. The final
- `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence will
- have `StepType.MID`.
-
- Attributes:
- step_type: A `StepType` enum value.
- reward: A scalar, or `None` if `step_type` is `StepType.FIRST`, i.e. at the
- start of a sequence.
- discount: A discount value in the range `[0, 1]`, or `None` if `step_type`
- is `StepType.FIRST`, i.e. at the start of a sequence.
- observation: A NumPy array, or a nested dict, list or tuple of arrays.
- """
- __slots__ = ()
-
- def first(self):
- return self.step_type is StepType.FIRST
-
- def mid(self):
- return self.step_type is StepType.MID
-
- def last(self):
- return self.step_type is StepType.LAST
-
-
-class StepType(enum.IntEnum):
- """Defines the status of a `TimeStep` within a sequence."""
- # Denotes the first `TimeStep` in a sequence.
- FIRST = 0
- # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
- MID = 1
- # Denotes the last `TimeStep` in a sequence.
- LAST = 2
-
- def first(self):
- return self is StepType.FIRST
-
- def mid(self):
- return self is StepType.MID
-
- def last(self):
- return self is StepType.LAST
-
-
-@six.add_metaclass(abc.ABCMeta)
-class Base(object):
- """Abstract base class for Python RL environments.
-
- Observations and valid actions are described with `ArraySpec`s, defined in
- the `specs` module.
- """
-
- @abc.abstractmethod
- def reset(self):
- """Starts a new sequence and returns the first `TimeStep` of this sequence.
-
- Returns:
- A `TimeStep` namedtuple containing:
- step_type: A `StepType` of `FIRST`.
- reward: `None`, indicating the reward is undefined.
- discount: `None`, indicating the discount is undefined.
- observation: A NumPy array, or a nested dict, list or tuple of arrays
- corresponding to `observation_spec()`.
- """
-
- @abc.abstractmethod
- def step(self, action):
- """Updates the environment according to the action and returns a `TimeStep`.
-
- If the environment returned a `TimeStep` with `StepType.LAST` at the
- previous step, this call to `step` will start a new sequence and `action`
- will be ignored.
-
- This method will also start a new sequence if called after the environment
- has been constructed and `reset` has not been called. Again, in this case
- `action` will be ignored.
-
- Args:
- action: A NumPy array, or a nested dict, list or tuple of arrays
- corresponding to `action_spec()`.
-
- Returns:
- A `TimeStep` namedtuple containing:
- step_type: A `StepType` value.
- reward: Reward at this timestep, or None if step_type is
- `StepType.FIRST`.
- discount: A discount in the range [0, 1], or None if step_type is
- `StepType.FIRST`.
- observation: A NumPy array, or a nested dict, list or tuple of arrays
- corresponding to `observation_spec()`.
- """
-
- @abc.abstractmethod
- def observation_spec(self):
- """Defines the observations provided by the environment.
-
- May use a subclass of `ArraySpec` that specifies additional properties such
- as min and max bounds on the values.
-
- Returns:
- An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
- """
-
- @abc.abstractmethod
- def action_spec(self):
- """Defines the actions that should be provided to `step`.
-
- May use a subclass of `ArraySpec` that specifies additional properties such
- as min and max bounds on the values.
-
- Returns:
- An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
- """
-
- def step_spec(self):
- """Optional method that defines fields returned by `step`.
-
- Implement this method to define an environment that uses non-standard values
- for any of the items returned by `step`. For example, an environment with
- array-valued rewards.
-
- Returns:
- A `TimeStep` namedtuple containing (possibly nested) `ArraySpec`s defining
- the reward, discount, and observation structure.
- """
- raise NotImplementedError
-
- def close(self):
- """Frees any resources used by the environment.
-
- Implement this method for an environment backed by an external process.
-
- This method can be used directly
-
- ```python
- env = Env(...)
- # Use env.
- env.close()
- ```
-
- or via a context manager
-
- ```python
- with Env(...) as env:
- # Use env.
- ```
- """
- pass
-
- def __enter__(self):
- """Allows the environment to be used in a with-statement context."""
- return self
-
- def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback):
- """Allows the environment to be used in a with-statement context."""
- self.close()
-
-# Helper functions for creating TimeStep namedtuples with default settings.
-
-
-def restart(observation):
- """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`."""
- return TimeStep(StepType.FIRST, None, None, observation)
-
-
-def transition(reward, observation, discount=1.0):
- """Returns a `TimeStep` with `step_type` set to `StepType.MID`."""
- return TimeStep(StepType.MID, reward, discount, observation)
-
-
-def termination(reward, observation):
- """Returns a `TimeStep` with `step_type` set to `StepType.LAST`."""
- return TimeStep(StepType.LAST, reward, 0.0, observation)
-
-
-def truncation(reward, observation, discount=1.0):
- """Returns a `TimeStep` with `step_type` set to `StepType.LAST`."""
- return TimeStep(StepType.LAST, reward, discount, observation)
diff --git a/dm_control/rl/specs.py b/dm_control/rl/specs.py
deleted file mode 100644
index 4e52dc39..00000000
--- a/dm_control/rl/specs.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Classes that describe the shape and dtype of numpy arrays."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
-import numpy as np
-
-
-class ArraySpec(object):
- """Describes a numpy array or scalar shape and dtype.
-
- An `ArraySpec` allows an API to describe the arrays that it accepts or
- returns, before that array exists.
- The equivalent version describing a `tf.Tensor` is `TensorSpec`.
- """
- __slots__ = ('_shape', '_dtype', '_name')
-
- def __init__(self, shape, dtype, name=None):
- """Initializes a new `ArraySpec`.
-
- Args:
- shape: An iterable specifying the array shape.
- dtype: numpy dtype or string specifying the array dtype.
- name: Optional string containing a semantic name for the corresponding
- array. Defaults to `None`.
-
- Raises:
- TypeError: If the shape is not an iterable or if the `dtype` is an invalid
- numpy dtype.
- """
- self._shape = tuple(shape)
- self._dtype = np.dtype(dtype)
- self._name = name
-
- @property
- def shape(self):
- """Returns a `tuple` specifying the array shape."""
- return self._shape
-
- @property
- def dtype(self):
- """Returns a numpy dtype specifying the array dtype."""
- return self._dtype
-
- @property
- def name(self):
- """Returns the name of the ArraySpec."""
- return self._name
-
- def __repr__(self):
- return 'ArraySpec(shape={}, dtype={}, name={})'.format(self.shape,
- repr(self.dtype),
- repr(self.name))
-
- def __eq__(self, other):
- """Checks if the shape and dtype of two specs are equal."""
- if not isinstance(other, ArraySpec):
- return False
- return self.shape == other.shape and self.dtype == other.dtype
-
- def __ne__(self, other):
- return not self == other
-
- def _fail_validation(self, message, *args):
- message %= args
- if self.name:
- message += ' for spec %s' % self.name
- raise ValueError(message)
-
- def validate(self, value):
- """Checks if value conforms to this spec.
-
- Args:
- value: a numpy array or value convertible to one via `np.asarray`.
-
- Returns:
- value, converted if necessary to a numpy array.
-
- Raises:
- ValueError: if value doesn't conform to this spec.
- """
- value = np.asarray(value)
- if value.shape != self.shape:
- self._fail_validation(
- 'Expected shape %r but found %r', self.shape, value.shape)
- if value.dtype != self.dtype:
- self._fail_validation(
- 'Expected dtype %s but found %s', self.dtype, value.dtype)
-
- def generate_value(self):
- """Generate a test value which conforms to this spec."""
- return np.zeros(shape=self.shape, dtype=self.dtype)
-
-
-class BoundedArraySpec(ArraySpec):
- """An `ArraySpec` that specifies minimum and maximum values.
-
- Example usage:
- ```python
- # Specifying the same minimum and maximum for every element.
- spec = BoundedArraySpec((3, 4), np.float64, minimum=0.0, maximum=1.0)
-
- # Specifying a different minimum and maximum for each element.
- spec = BoundedArraySpec(
- (2,), np.float64, minimum=[0.1, 0.2], maximum=[0.9, 0.9])
-
- # Specifying the same minimum and a different maximum for each element.
- spec = BoundedArraySpec(
- (3,), np.float64, minimum=-10.0, maximum=[4.0, 5.0, 3.0])
- ```
-
- Bounds are meant to be inclusive. This is especially important for
- integer types. The following spec will be satisfied by arrays
- with values in the set {0, 1, 2}:
- ```python
- spec = BoundedArraySpec((3, 4), np.int, minimum=0, maximum=2)
- ```
- """
-
- __slots__ = ('_minimum', '_maximum')
-
- def __init__(self, shape, dtype, minimum, maximum, name=None):
- """Initializes a new `BoundedArraySpec`.
-
- Args:
- shape: An iterable specifying the array shape.
- dtype: numpy dtype or string specifying the array dtype.
- minimum: Number or sequence specifying the minimum element bounds
- (inclusive). Must be broadcastable to `shape`.
- maximum: Number or sequence specifying the maximum element bounds
- (inclusive). Must be broadcastable to `shape`.
- name: Optional string containing a semantic name for the corresponding
- array. Defaults to `None`.
-
- Raises:
- ValueError: If `minimum` or `maximum` are not broadcastable to `shape`.
- TypeError: If the shape is not an iterable or if the `dtype` is an invalid
- numpy dtype.
- """
- super(BoundedArraySpec, self).__init__(shape, dtype, name)
-
- try:
- np.broadcast_to(minimum, shape=shape)
- except ValueError as numpy_exception:
- raise ValueError('minimum is not compatible with shape. '
- 'Message: {!r}.'.format(numpy_exception))
-
- try:
- np.broadcast_to(maximum, shape=shape)
- except ValueError as numpy_exception:
- raise ValueError('maximum is not compatible with shape. '
- 'Message: {!r}.'.format(numpy_exception))
-
- self._minimum = np.array(minimum)
- self._minimum.setflags(write=False)
-
- self._maximum = np.array(maximum)
- self._maximum.setflags(write=False)
-
- @property
- def minimum(self):
- """Returns a NumPy array specifying the minimum bounds (inclusive)."""
- return self._minimum
-
- @property
- def maximum(self):
- """Returns a NumPy array specifying the maximum bounds (inclusive)."""
- return self._maximum
-
- def __repr__(self):
- template = ('BoundedArraySpec(shape={}, dtype={}, name={}, '
- 'minimum={}, maximum={})')
- return template.format(self.shape, repr(self.dtype), repr(self.name),
- self._minimum, self._maximum)
-
- def __eq__(self, other):
- if not isinstance(other, BoundedArraySpec):
- return False
- return (super(BoundedArraySpec, self).__eq__(other) and
- (self.minimum == other.minimum).all() and
- (self.maximum == other.maximum).all())
-
- def validate(self, value):
- value = np.asarray(value)
- super(BoundedArraySpec, self).validate(value)
- if (value < self.minimum).any() or (value > self.maximum).any():
- self._fail_validation(
- 'Values were not all within bounds %s <= value <= %s',
- self.minimum, self.maximum)
-
- def generate_value(self):
- return (np.ones(shape=self.shape, dtype=self.dtype) *
- self.dtype.type(self.minimum))
diff --git a/dm_control/rl/specs_test.py b/dm_control/rl/specs_test.py
deleted file mode 100644
index 1b3feaaf..00000000
--- a/dm_control/rl/specs_test.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Tests for specs."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
-from absl.testing import absltest
-from dm_control.rl import specs as array_spec
-import numpy as np
-
-
-class ArraySpecTest(absltest.TestCase):
-
- def testShapeTypeError(self):
- with self.assertRaises(TypeError):
- array_spec.ArraySpec(32, np.int32)
-
- def testDtypeTypeError(self):
- with self.assertRaises(TypeError):
- array_spec.ArraySpec((1, 2, 3), "32")
-
- def testStringDtype(self):
- array_spec.ArraySpec((1, 2, 3), "int32")
-
- def testNumpyDtype(self):
- array_spec.ArraySpec((1, 2, 3), np.int32)
-
- def testDtype(self):
- spec = array_spec.ArraySpec((1, 2, 3), np.int32)
- self.assertEqual(np.int32, spec.dtype)
-
- def testShape(self):
- spec = array_spec.ArraySpec([1, 2, 3], np.int32)
- self.assertEqual((1, 2, 3), spec.shape)
-
- def testEqual(self):
- spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
- spec_2 = array_spec.ArraySpec((1, 2, 3), np.int32)
- self.assertEqual(spec_1, spec_2)
-
- def testNotEqualDifferentShape(self):
- spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
- spec_2 = array_spec.ArraySpec((1, 3, 3), np.int32)
- self.assertNotEqual(spec_1, spec_2)
-
- def testNotEqualDifferentDtype(self):
- spec_1 = array_spec.ArraySpec((1, 2, 3), np.int64)
- spec_2 = array_spec.ArraySpec((1, 2, 3), np.int32)
- self.assertNotEqual(spec_1, spec_2)
-
- def testNotEqualOtherClass(self):
- spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
- spec_2 = None
- self.assertNotEqual(spec_1, spec_2)
- self.assertNotEqual(spec_2, spec_1)
-
- spec_2 = ()
- self.assertNotEqual(spec_1, spec_2)
- self.assertNotEqual(spec_2, spec_1)
-
- def testValidateDtype(self):
- spec = array_spec.ArraySpec((1, 2), np.int32)
- spec.validate(np.zeros((1, 2), dtype=np.int32))
- with self.assertRaises(ValueError):
- spec.validate(np.zeros((1, 2), dtype=np.float32))
-
- def testValidateShape(self):
- spec = array_spec.ArraySpec((1, 2), np.int32)
- spec.validate(np.zeros((1, 2), dtype=np.int32))
- with self.assertRaises(ValueError):
- spec.validate(np.zeros((1, 2, 3), dtype=np.int32))
-
- def testGenerateValue(self):
- spec = array_spec.ArraySpec((1, 2), np.int32)
- test_value = spec.generate_value()
- spec.validate(test_value)
-
-
-class BoundedArraySpecTest(absltest.TestCase):
-
- def testInvalidMinimum(self):
- with self.assertRaisesRegexp(ValueError, "not compatible"):
- array_spec.BoundedArraySpec((3, 5), np.uint8, (0, 0, 0), (1, 1))
-
- def testInvalidMaximum(self):
- with self.assertRaisesRegexp(ValueError, "not compatible"):
- array_spec.BoundedArraySpec((3, 5), np.uint8, 0, (1, 1, 1))
-
- def testMinMaxAttributes(self):
- spec = array_spec.BoundedArraySpec((1, 2, 3), np.float32, 0, (5, 5, 5))
- self.assertEqual(type(spec.minimum), np.ndarray)
- self.assertEqual(type(spec.maximum), np.ndarray)
-
- def testNotWriteable(self):
- spec = array_spec.BoundedArraySpec((1, 2, 3), np.float32, 0, (5, 5, 5))
- with self.assertRaisesRegexp(ValueError, "read-only"):
- spec.minimum[0] = -1
- with self.assertRaisesRegexp(ValueError, "read-only"):
- spec.maximum[0] = 100
-
- def testEqualBroadcastingBounds(self):
- spec_1 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=0.0, maximum=1.0)
- spec_2 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
- self.assertEqual(spec_1, spec_2)
-
- def testNotEqualDifferentMinimum(self):
- spec_1 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
- spec_2 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
- self.assertNotEqual(spec_1, spec_2)
-
- def testNotEqualOtherClass(self):
- spec_1 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
- spec_2 = array_spec.ArraySpec((1, 2), np.int32)
- self.assertNotEqual(spec_1, spec_2)
- self.assertNotEqual(spec_2, spec_1)
-
- spec_2 = None
- self.assertNotEqual(spec_1, spec_2)
- self.assertNotEqual(spec_2, spec_1)
-
- spec_2 = ()
- self.assertNotEqual(spec_1, spec_2)
- self.assertNotEqual(spec_2, spec_1)
-
- def testNotEqualDifferentMaximum(self):
- spec_1 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=0.0, maximum=2.0)
- spec_2 = array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
- self.assertNotEqual(spec_1, spec_2)
-
- def testRepr(self):
- as_string = repr(array_spec.BoundedArraySpec(
- (1, 2), np.int32, minimum=101.0, maximum=73.0))
- self.assertIn("101", as_string)
- self.assertIn("73", as_string)
-
- def testValidateBounds(self):
- spec = array_spec.BoundedArraySpec((2, 2), np.int32, minimum=5, maximum=10)
- spec.validate(np.array([[5, 6], [8, 10]], dtype=np.int32))
- with self.assertRaises(ValueError):
- spec.validate(np.array([[5, 6], [8, 11]], dtype=np.int32))
- with self.assertRaises(ValueError):
- spec.validate(np.array([[4, 6], [8, 10]], dtype=np.int32))
-
- def testGenerateValue(self):
- spec = array_spec.BoundedArraySpec((2, 2), np.int32, minimum=5, maximum=10)
- test_value = spec.generate_value()
- spec.validate(test_value)
-
- def testScalarBounds(self):
- spec = array_spec.BoundedArraySpec((), np.float, minimum=0.0, maximum=1.0)
-
- self.assertIsInstance(spec.minimum, np.ndarray)
- self.assertIsInstance(spec.maximum, np.ndarray)
-
- # Sanity check that numpy compares correctly to a scalar for an empty shape.
- self.assertEqual(0.0, spec.minimum)
- self.assertEqual(1.0, spec.maximum)
-
- # Check that the spec doesn't fail its own input validation.
- _ = array_spec.BoundedArraySpec(
- spec.shape, spec.dtype, spec.minimum, spec.maximum)
-
-
-if __name__ == "__main__":
- absltest.main()
diff --git a/dm_control/suite/README.md b/dm_control/suite/README.md
index 9bfdbe83..d7dca5c5 100644
--- a/dm_control/suite/README.md
+++ b/dm_control/suite/README.md
@@ -1,4 +1,58 @@
# DeepMind Control Suite.
-This directory contains the domains and tasks described in the
+This submodule contains the domains and tasks described in the
[DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
+
+## Quickstart
+
+```python
+from dm_control import suite
+import numpy as np
+
+# Load one task:
+env = suite.load(domain_name="cartpole", task_name="swingup")
+
+# Iterate over a task set:
+for domain_name, task_name in suite.BENCHMARKING:
+ env = suite.load(domain_name, task_name)
+
+# Step through an episode and print out reward, discount and observation.
+action_spec = env.action_spec()
+time_step = env.reset()
+while not time_step.last():
+ action = np.random.uniform(action_spec.minimum,
+ action_spec.maximum,
+ size=action_spec.shape)
+ time_step = env.step(action)
+ print(time_step.reward, time_step.discount, time_step.observation)
+```
+
+## Illustration video
+
+Below is a video montage of solved Control Suite tasks, with reward
+visualisation enabled.
+
+[Control Suite video montage](https://www.youtube.com/watch?v=rAai4QzcYbs)
+
+
+### Quadruped domain [April 2019]
+
+Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
+
+- 4 DoFs per leg, 1 constraining tendon.
+- 3 actuators per leg: 'yaw', 'lift', 'extend'.
+- Filtered position actuators with timescale of 100ms.
+- Sensors include an IMU, force/torque sensors, and rangefinders.
+
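The "filtered position actuators" item can be pictured with a small sketch. It assumes the filter is a first-order low-pass of the form `act_dot = (ctrl - act) / timescale`, which is how MuJoCo documents its `filter` actuator dynamics; the function name `simulate_filter` and the Euler step size are illustrative choices, not part of the suite.

```python
import math


def simulate_filter(ctrl, timescale, duration, dt=1e-4):
  """Integrates act_dot = (ctrl - act) / timescale with explicit Euler."""
  act = 0.0  # The filter state starts at rest.
  for _ in range(round(duration / dt)):
    act += dt * (ctrl - act) / timescale
  return act


# Hold a unit control signal for one timescale (100 ms).
act_after_one_timescale = simulate_filter(ctrl=1.0, timescale=0.1, duration=0.1)
# Closed-form response of a first-order filter after one timescale.
expected = 1.0 - math.exp(-1.0)
```

Under this assumption, a step in the control signal reaches roughly 63% of its target after 100 ms, so abrupt action changes are smoothed before they reach the joints.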
+Four tasks:
+
+- `walk` and `run`: self-right the body then move forward at a desired speed.
+- `escape`: escape a bowl-shaped random terrain (uses rangefinders).
+- `fetch`: go to a moving ball and bring it to a target.
+
+All behaviors in the video below were trained with [Abdolmaleki et al.'s
+MPO](https://arxiv.org/abs/1806.06920).
+
+[Quadruped domain video](https://www.youtube.com/watch?v=RhRLjbb7pBE)
diff --git a/dm_control/suite/__init__.py b/dm_control/suite/__init__.py
index 021a4be9..e0b6c796 100644
--- a/dm_control/suite/__init__.py
+++ b/dm_control/suite/__init__.py
@@ -15,10 +15,6 @@
"""A collection of MuJoCo-based Reinforcement Learning environments."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
import inspect
import itertools
@@ -29,6 +25,7 @@
from dm_control.suite import ball_in_cup
from dm_control.suite import cartpole
from dm_control.suite import cheetah
+from dm_control.suite import dog
from dm_control.suite import finger
from dm_control.suite import fish
from dm_control.suite import hopper
@@ -38,6 +35,7 @@
from dm_control.suite import manipulator
from dm_control.suite import pendulum
from dm_control.suite import point_mass
+from dm_control.suite import quadruped
from dm_control.suite import reacher
from dm_control.suite import stacker
from dm_control.suite import swimmer
@@ -53,6 +51,7 @@ def _get_tasks(tag):
result = []
for domain_name in sorted(_DOMAINS.keys()):
+
domain = _DOMAINS[domain_name]
if tag is None:
@@ -84,12 +83,15 @@ def _get_tasks_by_domain(tasks):
EASY = _get_tasks('easy')
HARD = _get_tasks('hard')
EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING)))
+NO_REWARD_VIZ = _get_tasks('no_reward_visualization')
+REWARD_VIZ = tuple(sorted(set(ALL_TASKS) - set(NO_REWARD_VIZ)))
# A mapping from each domain name to a sequence of its task names.
TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)
-def load(domain_name, task_name, task_kwargs=None, visualize_reward=False):
+def load(domain_name, task_name, task_kwargs=None, environment_kwargs=None,
+ visualize_reward=False):
"""Returns an environment from a domain name, task name and optional settings.
```python
@@ -100,6 +102,8 @@ def load(domain_name, task_name, task_kwargs=None, visualize_reward=False):
domain_name: A string containing the name of a domain.
task_name: A string containing the name of a task.
task_kwargs: Optional `dict` of keyword arguments for the task.
+ environment_kwargs: Optional `dict` specifying keyword arguments for the
+ environment.
visualize_reward: Optional `bool`. If `True`, object colours in rendered
frames are set to indicate the reward at each step. Default `False`.
@@ -107,17 +111,19 @@ def load(domain_name, task_name, task_kwargs=None, visualize_reward=False):
The requested environment.
"""
return build_environment(domain_name, task_name, task_kwargs,
- visualize_reward)
+ environment_kwargs, visualize_reward)
def build_environment(domain_name, task_name, task_kwargs=None,
- visualize_reward=False):
+ environment_kwargs=None, visualize_reward=False):
"""Returns an environment from the suite given a domain name and a task name.
Args:
domain_name: A string containing the name of a domain.
task_name: A string containing the name of a task.
task_kwargs: Optional `dict` specifying keyword arguments for the task.
+ environment_kwargs: Optional `dict` specifying keyword arguments for the
+ environment.
visualize_reward: Optional `bool`. If `True`, object colours in rendered
frames are set to indicate the reward at each step. Default `False`.
@@ -137,6 +143,8 @@ def build_environment(domain_name, task_name, task_kwargs=None,
task_name, domain_name))
task_kwargs = task_kwargs or {}
+ if environment_kwargs is not None:
+ task_kwargs = dict(task_kwargs, environment_kwargs=environment_kwargs)
env = domain.SUITE[task_name](**task_kwargs)
env.task.visualize_reward = visualize_reward
return env
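The keyword forwarding in `build_environment` can be illustrated without MuJoCo. This is a minimal sketch rather than the library code: `fake_swingup` and `build` are hypothetical stand-ins for a task constructor registered in `domain.SUITE` and for `build_environment`, and `flat_observation` is used only as an example key.

```python
def fake_swingup(time_limit=10, random=None, environment_kwargs=None):
  """Hypothetical stand-in for a suite task constructor."""
  environment_kwargs = environment_kwargs or {}
  # The real constructors pass these settings on to `control.Environment`;
  # here they are simply collected into a dict for inspection.
  return dict(time_limit=time_limit, random=random, **environment_kwargs)


def build(task_fn, task_kwargs=None, environment_kwargs=None):
  """Mirrors the merging logic of `build_environment` shown above."""
  task_kwargs = task_kwargs or {}
  if environment_kwargs is not None:
    # dict(...) builds a new mapping, so the caller's dict is not mutated.
    task_kwargs = dict(task_kwargs, environment_kwargs=environment_kwargs)
  return task_fn(**task_kwargs)


user_task_kwargs = {'random': 42}
env_settings = build(fake_swingup, user_task_kwargs,
                     environment_kwargs={'flat_observation': True})
```

Because `environment_kwargs` is only merged in when it is not `None`, task constructors that do not accept the argument keep working as long as callers do not pass it.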
diff --git a/dm_control/suite/acrobot.py b/dm_control/suite/acrobot.py
index 72fe30cb..900e8cb0 100644
--- a/dm_control/suite/acrobot.py
+++ b/dm_control/suite/acrobot.py
@@ -15,21 +15,14 @@
"""Acrobot domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
_DEFAULT_TIME_LIMIT = 10
@@ -42,19 +35,25 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns Acrobot balance task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(sparse=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add('benchmarking')
-def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns Acrobot sparse balance."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(sparse=True, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
class Physics(mujoco.Physics):
@@ -92,7 +91,7 @@ def __init__(self, sparse, random=None):
automatically (default).
"""
self._sparse = sparse
- super(Balance, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -104,6 +103,7 @@ def initialize_episode(self, physics):
"""
physics.named.data.qpos[
['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of pole orientation and angular velocities."""
diff --git a/dm_control/suite/acrobot.xml b/dm_control/suite/acrobot.xml
index 6d05fe35..79b76d9c 100644
--- a/dm_control/suite/acrobot.xml
+++ b/dm_control/suite/acrobot.xml
@@ -24,7 +24,7 @@ Based on Coulomb's [1] rather than Spong's [2] model.
-
+
diff --git a/all_domains.png b/dm_control/suite/all_domains.png
similarity index 100%
rename from all_domains.png
rename to dm_control/suite/all_domains.png
diff --git a/dm_control/suite/ball_in_cup.py b/dm_control/suite/ball_in_cup.py
index 2eab2471..33af720b 100644
--- a/dm_control/suite/ball_in_cup.py
+++ b/dm_control/suite/ball_in_cup.py
@@ -15,14 +15,8 @@
"""Ball-in-Cup Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -42,12 +36,14 @@ def get_model_and_assets():
@SUITE.add('benchmarking', 'easy')
-def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Ball-in-Cup task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = BallInCup(random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -86,6 +82,7 @@ def initialize_episode(self, physics):
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of the state."""
diff --git a/dm_control/suite/base.py b/dm_control/suite/base.py
index a09e51d2..7e561c09 100644
--- a/dm_control/suite/base.py
+++ b/dm_control/suite/base.py
@@ -15,22 +15,18 @@
"""Base class for tasks in the Control Suite."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
from dm_control import mujoco
from dm_control.rl import control
-
import numpy as np
class Task(control.Task):
"""Base class for tasks in the Control Suite.
- Maps actions directly to the states of MuJoCo actuators.
+ Actions are mapped directly to the states of MuJoCo actuators: each element of
+ the action array is used to set the control input for a single actuator. The
+ ordering of the actuators is the same as in the corresponding MJCF XML file.
Attributes:
random: A `numpy.random.RandomState` instance. This should be used to
@@ -63,17 +59,22 @@ def action_spec(self, physics):
"""Returns a `BoundedArraySpec` matching the `physics` actuators."""
return mujoco.action_spec(physics)
+ def initialize_episode(self, physics):
+ """Resets geom colors to their defaults after starting a new episode.
+
+ Subclasses of `base.Task` must delegate to this method after performing
+ their own initialization.
+
+ Args:
+ physics: An instance of `mujoco.Physics`.
+ """
+ self.after_step(physics)
+
def before_step(self, action, physics):
"""Sets the control signal for the actuators to values in `action`."""
# Support legacy internal code.
- try:
- physics.set_control(action.continuous_actions)
- except AttributeError:
- physics.set_control(action)
-
- # Reset any reward visualisation at the start of a new episode.
- if self._visualize_reward and physics.time() == 0.0:
- _set_reward_colors(physics, reward=0.0)
+ action = getattr(action, "continuous_actions", action)
+ physics.set_control(action)
def after_step(self, physics):
"""Modifies colors according to the reward."""
@@ -92,15 +93,16 @@ def visualize_reward(self, value):
self._visualize_reward = value
+_MATERIALS = ["self", "effector", "target"]
+_DEFAULT = [name + "_default" for name in _MATERIALS]
+_HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]
+
+
def _set_reward_colors(physics, reward):
"""Sets the highlight, effector and target colors according to the reward."""
assert 0.0 <= reward <= 1.0
-
colors = physics.named.model.mat_rgba
-
- def blend(color1, color2):
- return reward * colors[color1] + (1.0 - reward) * colors[color2]
-
- colors["self"] = blend("self_highlight", "self_default")
- colors["effector"] = blend("effector_highlight", "effector_default")
- colors["target"] = blend("target_highlight", "target_default")
+ default = colors[_DEFAULT]
+ highlight = colors[_HIGHLIGHT]
+ blend_coef = reward ** 4 # Better color distinction near high rewards.
+ colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default
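The effect of the `reward ** 4` coefficient in `_set_reward_colors` can be checked with plain lists. This is a sketch only: `blend_rgba` is a hypothetical helper, and the two colours are made-up values rather than entries from `materials.xml`.

```python
def blend_rgba(default, highlight, reward):
  """Blends two RGBA colours using the reward ** 4 coefficient."""
  assert 0.0 <= reward <= 1.0
  blend_coef = reward ** 4  # Stays small until the reward is close to 1.
  return [blend_coef * h + (1.0 - blend_coef) * d
          for d, h in zip(default, highlight)]


default = [0.0, 0.0, 0.0, 1.0]    # Hypothetical default colour.
highlight = [1.0, 1.0, 1.0, 1.0]  # Hypothetical highlight colour.
half = blend_rgba(default, highlight, 0.5)
```

With a linear blend, a reward of 0.5 would mix the two colours equally. With the fourth power the coefficient is only 0.0625, so the rendered colour stays near the default until the reward is high.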
diff --git a/dm_control/suite/cartpole.py b/dm_control/suite/cartpole.py
index 18775f69..4b6a026e 100644
--- a/dm_control/suite/cartpole.py
+++ b/dm_control/suite/cartpole.py
@@ -15,24 +15,16 @@
"""Cartpole domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
-
from lxml import etree
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
_DEFAULT_TIME_LIMIT = 10
@@ -45,51 +37,69 @@ def get_model_and_assets(num_poles=1):
@SUITE.add('benchmarking')
-def balance(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def balance(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the Cartpole Balance task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(swing_up=False, sparse=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add('benchmarking')
-def balance_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def balance_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the sparse reward variant of the Cartpole Balance task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(swing_up=False, sparse=True, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add('benchmarking')
-def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, **kwargs):
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the Cartpole Swing-Up task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(swing_up=True, sparse=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit, **kwargs)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add('benchmarking')
-def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
- """Returns the sparse reward variant of teh Cartpole Swing-Up task."""
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the sparse reward variant of the Cartpole Swing-Up task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Balance(swing_up=True, sparse=True, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add()
-def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None):
- """Returns the Cartpole Balance task."""
+def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Cartpole Balance task with two poles."""
physics = Physics.from_xml_string(*get_model_and_assets(num_poles=2))
task = Balance(swing_up=True, sparse=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add()
-def three_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None):
- """Returns the Cartpole Balance task."""
- physics = Physics.from_xml_string(*get_model_and_assets(num_poles=3))
- task = Balance(swing_up=True, sparse=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+def three_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None, num_poles=3,
+ sparse=False, environment_kwargs=None):
+ """Returns the Cartpole Balance task with three or more poles."""
+ physics = Physics.from_xml_string(*get_model_and_assets(num_poles=num_poles))
+ task = Balance(swing_up=True, sparse=sparse, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
def _make_model(n_poles):
@@ -100,7 +110,7 @@ def _make_model(n_poles):
mjcf = etree.fromstring(xml_string)
parent = mjcf.find('./worldbody/body/body') # Find first pole.
# Make chain of poles.
- for pole_index in xrange(2, n_poles+1):
+ for pole_index in range(2, n_poles+1):
child = etree.Element('body', name='pole_{}'.format(pole_index),
pos='0 0 1', childclass='pole')
etree.SubElement(child, 'joint', name='hinge_{}'.format(pole_index))
@@ -162,7 +172,7 @@ def __init__(self, swing_up, sparse, random=None):
"""
self._sparse = sparse
self._swing_up = swing_up
- super(Balance, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -182,6 +192,7 @@ def initialize_episode(self, physics):
physics.named.data.qpos['slider'] = self.random.uniform(-.1, .1)
physics.named.data.qpos[1:] = self.random.uniform(-.034, .034, nv - 1)
physics.named.data.qvel[:] = 0.01 * self.random.randn(physics.model.nv)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of the (bounded) physics state."""
diff --git a/dm_control/suite/cartpole.xml b/dm_control/suite/cartpole.xml
index af638e5f..e01869dd 100644
--- a/dm_control/suite/cartpole.xml
+++ b/dm_control/suite/cartpole.xml
@@ -17,7 +17,7 @@
-
+
diff --git a/dm_control/suite/cheetah.py b/dm_control/suite/cheetah.py
index 9b26bf7b..109d71b7 100644
--- a/dm_control/suite/cheetah.py
+++ b/dm_control/suite/cheetah.py
@@ -15,14 +15,8 @@
"""Cheetah Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -46,11 +40,13 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the run task."""
physics = Physics.from_xml_string(*get_model_and_assets())
- task = Cheetah(random)
- return control.Environment(physics, task, time_limit=time_limit)
+ task = Cheetah(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -58,37 +54,32 @@ class Physics(mujoco.Physics):
def speed(self):
"""Returns the horizontal speed of the Cheetah."""
- return self.named.data.subtree_linvel['torso', 'x']
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
class Cheetah(base.Task):
"""A `Task` to train a running Cheetah."""
- def __init__(self, random=None):
- """Initializes an instance of `Cheetah`.
-
- Args:
- random: Optional, either a `numpy.random.RandomState` instance, an
- integer seed for creating a new `RandomState`, or None to select a seed
- automatically (default).
- """
- super(Cheetah, self).__init__(random)
-
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
+ # The indexing below assumes that all joints have a single DOF.
+ assert physics.model.nq == physics.model.njnt
+ is_limited = physics.model.jnt_limited == 1
+ lower, upper = physics.model.jnt_range[is_limited].T
+ physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
# Stabilize the model before the actual simulation.
- for _ in range(200):
- physics.step()
+ physics.step(nstep=200)
physics.data.time = 0
self._timeout_progress = 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of the state, ignoring horizontal position."""
obs = collections.OrderedDict()
# Ignores horizontal position to maintain translational invariance.
- obs['position'] = physics.data.qpos[1:]
+ obs['position'] = physics.data.qpos[1:].copy()
obs['velocity'] = physics.velocity()
return obs
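The new randomisation in `Cheetah.initialize_episode` samples every limited joint uniformly within its range and leaves unlimited joints untouched. A dependency-free sketch of that idea follows; `sample_limited_qpos` is a hypothetical helper, the joint data are invented, and the standard-library `random.Random` stands in for the task's NumPy `RandomState`.

```python
import random


def sample_limited_qpos(qpos, jnt_limited, jnt_range, rng):
  """Returns a copy of qpos with limited joints resampled within range."""
  # As in the diff, this assumes one qpos entry per joint.
  assert len(qpos) == len(jnt_limited) == len(jnt_range)
  new_qpos = list(qpos)
  for i, (limited, (lower, upper)) in enumerate(zip(jnt_limited, jnt_range)):
    if limited:
      new_qpos[i] = rng.uniform(lower, upper)
  return new_qpos


rng = random.Random(0)
qpos = [0.0, 0.0, 0.0]
jnt_limited = [False, True, True]                   # First joint is unlimited.
jnt_range = [(0.0, 0.0), (-0.5, 0.5), (-1.0, 0.2)]  # Made-up ranges.
sampled = sample_limited_qpos(qpos, jnt_limited, jnt_range, rng)
```

The assertion on `nq == njnt` in the diff serves the same purpose as the length check here: the boolean mask built from per-joint data is only valid as a `qpos` index when each joint has exactly one coordinate.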
diff --git a/dm_control/suite/cheetah.xml b/dm_control/suite/cheetah.xml
index ef396644..32be2872 100644
--- a/dm_control/suite/cheetah.xml
+++ b/dm_control/suite/cheetah.xml
@@ -58,6 +58,10 @@
+
+
+
+
@@ -66,5 +70,4 @@
-
diff --git a/dm_control/suite/common/__init__.py b/dm_control/suite/common/__init__.py
index 0c636473..4518ceb2 100644
--- a/dm_control/suite/common/__init__.py
+++ b/dm_control/suite/common/__init__.py
@@ -15,18 +15,14 @@
"""Functions to manage the common assets for domains."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import os
-from dm_control.utils import resources
+from dm_control.utils import io as resources
_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
_FILENAMES = [
- "common/materials.xml",
- "common/skybox.xml",
- "common/visual.xml",
+ "./common/materials.xml",
+ "./common/skybox.xml",
+ "./common/visual.xml",
]
ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
diff --git a/dm_control/suite/common/materials.xml b/dm_control/suite/common/materials.xml
index cae6635e..396b58d5 100644
--- a/dm_control/suite/common/materials.xml
+++ b/dm_control/suite/common/materials.xml
@@ -18,5 +18,6 @@ for example receiving a positive reward.
+
diff --git a/dm_control/suite/demos/mocap_demo.py b/dm_control/suite/demos/mocap_demo.py
index 427daf9e..9b8601ad 100644
--- a/dm_control/suite/demos/mocap_demo.py
+++ b/dm_control/suite/demos/mocap_demo.py
@@ -22,19 +22,12 @@
CMU motion capture clips are available at mocap.cs.cmu.edu
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import time
-# Internal dependencies.
from absl import app
from absl import flags
-
from dm_control.suite import humanoid_CMU
from dm_control.suite.utils import parse_amc
-
import matplotlib.pyplot as plt
import numpy as np
diff --git a/dm_control/suite/dog.py b/dm_control/suite/dog.py
new file mode 100644
index 00000000..1b574072
--- /dev/null
+++ b/dm_control/suite/dog.py
@@ -0,0 +1,449 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Dog Domain."""
+
+import collections
+import os
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+
+from dm_control.utils import io as resources
+
+_DEFAULT_TIME_LIMIT = 15
+_CONTROL_TIMESTEP = .015
+
+# Angle (in degrees) of local z from global z below which upright reward is 1.
+_MAX_UPRIGHT_ANGLE = 30
+_MIN_UPRIGHT_COSINE = np.cos(np.deg2rad(_MAX_UPRIGHT_ANGLE))
+
+# Standing reward is 1 for body-over-foot height that is at least this fraction
+# of the height at the default pose.
+_STAND_HEIGHT_FRACTION = 0.9
+
+# Torques which enforce joint range limits should stay below this value.
+_EXCESSIVE_LIMIT_TORQUES = 150
+
+# Horizontal speed above which Move reward is 1.
+_WALK_SPEED = 1
+_TROT_SPEED = 3
+_RUN_SPEED = 9
+
+_HINGE_TYPE = mujoco.wrapper.mjbindings.enums.mjtJoint.mjJNT_HINGE
+_LIMIT_TYPE = mujoco.wrapper.mjbindings.enums.mjtConstraint.mjCNSTR_LIMIT_JOINT
+
+SUITE = containers.TaggedTasks()
+
+_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'dog_assets')
+
+
+def make_model(floor_size, remove_ball):
+ """Sets floor size, removes ball and walls (Stand and Move tasks)."""
+ xml_string = common.read_model('dog.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Set floor size.
+ floor = xml_tools.find_element(mjcf, 'geom', 'floor')
+ floor.attrib['size'] = str(floor_size) + ' ' + str(floor_size) + ' .1'
+
+ if remove_ball:
+ # Remove ball, target and walls.
+ ball = xml_tools.find_element(mjcf, 'body', 'ball')
+ ball.getparent().remove(ball)
+ target = xml_tools.find_element(mjcf, 'geom', 'target')
+ target.getparent().remove(target)
+ ball_cam = xml_tools.find_element(mjcf, 'camera', 'ball')
+ ball_cam.getparent().remove(ball_cam)
+ head_cam = xml_tools.find_element(mjcf, 'camera', 'head')
+ head_cam.getparent().remove(head_cam)
+ for wall_name in ['px', 'nx', 'py', 'ny']:
+ wall = xml_tools.find_element(mjcf, 'geom', 'wall_' + wall_name)
+ wall.getparent().remove(wall)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+def get_model_and_assets(floor_size=10, remove_ball=True):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ assets = common.ASSETS.copy()
+ _, _, filenames = next(resources.WalkResources(_ASSET_DIR))
+ for filename in filenames:
+ assets[filename] = resources.GetResource(os.path.join(_ASSET_DIR, filename))
+ return make_model(floor_size, remove_ball), assets
+
+
+@SUITE.add('no_reward_visualization')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Stand task."""
+ floor_size = _WALK_SPEED * _DEFAULT_TIME_LIMIT
+ physics = Physics.from_xml_string(*get_model_and_assets(floor_size))
+ task = Stand(random=random)
+  environment_kwargs = environment_kwargs or {}
+  return control.Environment(physics, task, time_limit=time_limit,
+                             control_timestep=_CONTROL_TIMESTEP,
+                             **environment_kwargs)
+
+
+@SUITE.add('no_reward_visualization')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+  """Returns the Walk task."""
+  move_speed = _WALK_SPEED
+  floor_size = move_speed * _DEFAULT_TIME_LIMIT
+  physics = Physics.from_xml_string(*get_model_and_assets(floor_size))
+  task = Move(move_speed=move_speed, random=random)
+  environment_kwargs = environment_kwargs or {}
+  return control.Environment(physics, task, time_limit=time_limit,
+                             control_timestep=_CONTROL_TIMESTEP,
+                             **environment_kwargs)
+
+
+@SUITE.add('no_reward_visualization')
+def trot(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+  """Returns the Trot task."""
+  move_speed = _TROT_SPEED
+  floor_size = move_speed * _DEFAULT_TIME_LIMIT
+  physics = Physics.from_xml_string(*get_model_and_assets(floor_size))
+  task = Move(move_speed=move_speed, random=random)
+  environment_kwargs = environment_kwargs or {}
+  return control.Environment(physics, task, time_limit=time_limit,
+                             control_timestep=_CONTROL_TIMESTEP,
+                             **environment_kwargs)
+
+
+@SUITE.add('no_reward_visualization')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+  """Returns the Run task."""
+  move_speed = _RUN_SPEED
+  floor_size = move_speed * _DEFAULT_TIME_LIMIT
+  physics = Physics.from_xml_string(*get_model_and_assets(floor_size))
+  task = Move(move_speed=move_speed, random=random)
+  environment_kwargs = environment_kwargs or {}
+  return control.Environment(physics, task, time_limit=time_limit,
+                             control_timestep=_CONTROL_TIMESTEP,
+                             **environment_kwargs)
+
+
+@SUITE.add('no_reward_visualization', 'hard')
+def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+  """Returns the Fetch task."""
+  physics = Physics.from_xml_string(*get_model_and_assets(remove_ball=False))
+  task = Fetch(random=random)
+  environment_kwargs = environment_kwargs or {}
+  return control.Environment(physics, task, time_limit=time_limit,
+                             control_timestep=_CONTROL_TIMESTEP,
+                             **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+  """Physics simulation with additional features for the Dog domain."""
+
+  def torso_pelvis_height(self):
+    """Returns the heights of the torso and pelvis."""
+    return self.named.data.xpos[['torso', 'pelvis'], 'z']
+
+  def z_projection(self):
+    """Returns rotation-invariant projection of local frames to the world z."""
+    return np.vstack((self.named.data.xmat['skull', ['zx', 'zy', 'zz']],
+                      self.named.data.xmat['torso', ['zx', 'zy', 'zz']],
+                      self.named.data.xmat['pelvis', ['zx', 'zy', 'zz']]))
+
+  def upright(self):
+    """Returns projection from local z-axes to the z-axis of world."""
+    return self.z_projection()[:, 2]
+
+  def center_of_mass_velocity(self):
+    """Returns the velocity of the center-of-mass."""
+    return self.named.data.sensordata['torso_linvel']
+
+  def torso_com_velocity(self):
+    """Returns the velocity of the center-of-mass in the torso frame."""
+    torso_frame = self.named.data.xmat['torso'].reshape(3, 3).copy()
+    return self.center_of_mass_velocity().dot(torso_frame)
+
+  def com_forward_velocity(self):
+    """Returns the com velocity in the torso's forward direction."""
+    return self.torso_com_velocity()[0]
+
+  def joint_angles(self):
+    """Returns the configuration of all hinge joints (skipping free joints)."""
+    hinge_joints = self.model.jnt_type == _HINGE_TYPE
+    qpos_index = self.model.jnt_qposadr[hinge_joints]
+    return self.data.qpos[qpos_index].copy()
+
+  def joint_velocities(self):
+    """Returns the velocity of all hinge joints (skipping free joints)."""
+    hinge_joints = self.model.jnt_type == _HINGE_TYPE
+    qvel_index = self.model.jnt_dofadr[hinge_joints]
+    return self.data.qvel[qvel_index].copy()
+
+  def inertial_sensors(self):
+    """Returns inertial sensor readings."""
+    return self.named.data.sensordata[['accelerometer', 'velocimeter', 'gyro']]
+
+  def touch_sensors(self):
+    """Returns touch readings."""
+    return self.named.data.sensordata[['palm_L', 'palm_R', 'sole_L', 'sole_R']]
+
+  def foot_forces(self):
+    """Returns force sensor readings for the feet and hands."""
+    return self.named.data.sensordata[['foot_L', 'foot_R', 'hand_L', 'hand_R']]
+
+  def ball_in_head_frame(self):
+    """Returns the ball position and velocity in the frame of the head."""
+    head_frame = self.named.data.site_xmat['head'].reshape(3, 3)
+    head_pos = self.named.data.site_xpos['head']
+    ball_pos = self.named.data.geom_xpos['ball']
+    head_to_ball = ball_pos - head_pos
+    head_vel, _ = self.data.object_velocity('head', 'site')
+    ball_vel, _ = self.data.object_velocity('ball', 'geom')
+    head_to_ball_vel = ball_vel - head_vel
+    return np.hstack((head_to_ball.dot(head_frame),
+                      head_to_ball_vel.dot(head_frame)))
+
+  def target_in_head_frame(self):
+    """Returns the target position in the frame of the head."""
+    head_frame = self.named.data.site_xmat['head'].reshape(3, 3)
+    head_pos = self.named.data.site_xpos['head']
+    target_pos = self.named.data.geom_xpos['target']
+    head_to_target = target_pos - head_pos
+    return head_to_target.dot(head_frame)
+
+  def ball_to_mouth_distance(self):
+    """Returns the distance from the ball to the mouth."""
+    ball_pos = self.named.data.geom_xpos['ball']
+    upper_bite_pos = self.named.data.site_xpos['upper_bite']
+    lower_bite_pos = self.named.data.site_xpos['lower_bite']
+    upper_dist = np.linalg.norm(ball_pos - upper_bite_pos)
+    lower_dist = np.linalg.norm(ball_pos - lower_bite_pos)
+    return 0.5*(upper_dist + lower_dist)
+
+  def ball_to_target_distance(self):
+    """Returns the distance from the ball to the target."""
+    ball_pos, target_pos = self.named.data.geom_xpos[['ball', 'target']]
+    return np.linalg.norm(ball_pos - target_pos)
+
+
+class Stand(base.Task):
+  """A dog stand task generating upright posture."""
+
+  def __init__(self, random=None, observe_reward_factors=False):
+    """Initializes an instance of `Stand`.
+
+    Args:
+      random: Optional, either a `numpy.random.RandomState` instance, an
+        integer seed for creating a new `RandomState`, or None to select a seed
+        automatically (default).
+      observe_reward_factors: Boolean, whether the factorised reward is a
+        key in the observation dict returned to the agent.
+    """
+    self._observe_reward_factors = observe_reward_factors
+    super().__init__(random=random)
+
+  def initialize_episode(self, physics):
+    """Randomizes initial root orientation, root velocities and actuator states.
+
+    Args:
+      physics: An instance of `Physics`.
+
+    """
+    physics.reset()
+
+    # Measure stand heights from default pose, above which stand reward is 1.
+    self._stand_height = physics.torso_pelvis_height() * _STAND_HEIGHT_FRACTION
+
+    # Measure body weight.
+    body_mass = physics.named.model.body_subtreemass['torso']
+    self._body_weight = -physics.model.opt.gravity[2] * body_mass
+
+    # Randomize horizontal orientation.
+    azimuth = self.random.uniform(0, 2*np.pi)
+    orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
+    physics.named.data.qpos['root'][3:] = orientation
+
+    # Randomize root velocities in horizontal plane.
+    physics.data.qvel[0] = 2 * self.random.randn()
+    physics.data.qvel[1] = 2 * self.random.randn()
+    physics.data.qvel[5] = 2 * self.random.randn()
+
+    # Randomize actuator states.
+    assert physics.model.nu == physics.model.na
+    for actuator_id in range(physics.model.nu):
+      ctrlrange = physics.model.actuator_ctrlrange[actuator_id]
+      physics.data.act[actuator_id] = self.random.uniform(*ctrlrange)
+
+  def get_observation_components(self, physics):
+    """Returns the observations for the Stand task."""
+    obs = collections.OrderedDict()
+    obs['joint_angles'] = physics.joint_angles()
+    obs['joint_velocites'] = physics.joint_velocities()
+    obs['torso_pelvis_height'] = physics.torso_pelvis_height()
+    obs['z_projection'] = physics.z_projection().flatten()
+    obs['torso_com_velocity'] = physics.torso_com_velocity()
+    obs['inertial_sensors'] = physics.inertial_sensors()
+    obs['foot_forces'] = physics.foot_forces()
+    obs['touch_sensors'] = physics.touch_sensors()
+    obs['actuator_state'] = physics.data.act.copy()
+    return obs
+
+  def get_observation(self, physics):
+    """Returns the observation, possibly adding reward factors."""
+    obs = self.get_observation_components(physics)
+    if self._observe_reward_factors:
+      obs['reward_factors'] = self.get_reward_factors(physics)
+    return obs
+
+  def get_reward_factors(self, physics):
+    """Returns the factorized reward."""
+    # Keep the torso at standing height.
+    torso = rewards.tolerance(physics.torso_pelvis_height()[0],
+                              bounds=(self._stand_height[0], float('inf')),
+                              margin=self._stand_height[0])
+    # Keep the pelvis at standing height.
+    pelvis = rewards.tolerance(physics.torso_pelvis_height()[1],
+                               bounds=(self._stand_height[1], float('inf')),
+                               margin=self._stand_height[1])
+    # Keep head, torso and pelvis upright.
+    upright = rewards.tolerance(physics.upright(),
+                                bounds=(_MIN_UPRIGHT_COSINE, float('inf')),
+                                sigmoid='linear',
+                                margin=_MIN_UPRIGHT_COSINE+1,
+                                value_at_margin=0)
+
+    # Reward for foot touch forces up to bodyweight.
+    touch = rewards.tolerance(physics.touch_sensors().sum(),
+                              bounds=(self._body_weight, float('inf')),
+                              margin=self._body_weight,
+                              sigmoid='linear',
+                              value_at_margin=0.9)
+
+    return np.hstack((torso, pelvis, upright, touch))
+
+  def get_reward(self, physics):
+    """Returns the reward, product of reward factors."""
+    return np.prod(self.get_reward_factors(physics))
+
+
+class Move(Stand):
+  """A dog move task for generating locomotion."""
+
+  def __init__(self, move_speed, random, observe_reward_factors=False):
+    """Initializes an instance of `Move`.
+
+    Args:
+      move_speed: A float. Specifies a target horizontal velocity.
+      random: Optional, either a `numpy.random.RandomState` instance, an
+        integer seed for creating a new `RandomState`, or None to select a seed
+        automatically (default).
+      observe_reward_factors: Boolean, whether the factorised reward is a
+        component of the observation dict.
+    """
+    self._move_speed = move_speed
+    super().__init__(random, observe_reward_factors)
+
+  def get_reward_factors(self, physics):
+    """Returns the factorized reward."""
+    standing = super().get_reward_factors(physics)
+
+    speed_margin = max(1.0, self._move_speed)
+    forward = rewards.tolerance(physics.com_forward_velocity(),
+                                bounds=(self._move_speed, 2*self._move_speed),
+                                margin=speed_margin,
+                                value_at_margin=0,
+                                sigmoid='linear')
+    forward = (4*forward + 1) / 5
+
+    return np.hstack((standing, forward))
+
+
+class Fetch(Stand):
+  """A dog fetch task to fetch a thrown ball."""
+
+  def __init__(self, random, observe_reward_factors=False):
+    """Initializes an instance of `Fetch`.
+
+    Args:
+      random: Optional, either a `numpy.random.RandomState` instance, an
+        integer seed for creating a new `RandomState`, or None to select a seed
+        automatically (default).
+      observe_reward_factors: Boolean, whether the factorised reward is a
+        component of the observation dict.
+    """
+    super().__init__(random, observe_reward_factors)
+
+  def initialize_episode(self, physics):
+    super().initialize_episode(physics)
+
+    # Set initial ball state: flying towards the center at an upward angle.
+    radius = 0.75 * physics.named.model.geom_size['floor', 0]
+    azimuth = self.random.uniform(0, 2*np.pi)
+    position = (radius*np.sin(azimuth), radius*np.cos(azimuth), 0.05)
+    physics.named.data.qpos['ball_root'][:3] = position
+    vertical_height = self.random.uniform(0, 3)
+    # Equating kinetic and potential energy: mv^2/2 = m*g*h -> v = sqrt(2gh)
+    gravity = -physics.model.opt.gravity[2]
+    vertical_velocity = np.sqrt(2 * gravity * vertical_height)
+    horizontal_speed = self.random.uniform(0, 5)
+    # Pointing towards the center, with some noise.
+    direction = np.array((-np.sin(azimuth) + 0.05*self.random.randn(),
+                          -np.cos(azimuth) + 0.05*self.random.randn()))
+    horizontal_velocity = horizontal_speed * direction
+    velocity = np.hstack((horizontal_velocity, vertical_velocity))
+    physics.named.data.qvel['ball_root'][:3] = velocity
+
+  def get_observation_components(self, physics):
+    """Returns the observations for the Fetch task."""
+    obs = super().get_observation_components(physics)
+    obs['ball_state'] = physics.ball_in_head_frame()
+    obs['target_position'] = physics.target_in_head_frame()
+    return obs
+
+  def get_reward_factors(self, physics):
+    """Returns the factorized reward."""
+    standing = super().get_reward_factors(physics)
+
+    # Reward for bringing mouth close to ball.
+    bite_radius = physics.named.model.site_size['upper_bite', 0]
+    bite_margin = 2
+    reach_ball = rewards.tolerance(physics.ball_to_mouth_distance(),
+                                   bounds=(0, bite_radius),
+                                   sigmoid='reciprocal',
+                                   margin=bite_margin)
+    reach_ball = (6*reach_ball + 1) / 7
+
+    # Reward for bringing the ball close to the target.
+    target_radius = physics.named.model.geom_size['target', 0]
+    bring_margin = physics.named.model.geom_size['floor', 0]
+    ball_near_target = rewards.tolerance(
+        physics.ball_to_target_distance(),
+        bounds=(0, target_radius),
+        sigmoid='reciprocal',
+        margin=bring_margin)
+    fetch_ball = (ball_near_target + 1) / 2
+
+    # Let go of the ball if it's been fetched.
+    if physics.ball_to_target_distance() < 2*target_radius:
+      reach_ball = 1
+
+    return np.hstack((standing, reach_ball, fetch_ball))
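Note on the `dog.py` hunk above: the task code inlines three small pieces of arithmetic that are easier to check in isolation. The sketch below restates them as standalone functions: the yaw quaternion used to randomize the root orientation, the `v = sqrt(2gh)` launch speed from the ball initialization, and the `(w*x + 1) / (w + 1)` rescaling applied to the `forward`, `reach_ball` and `fetch_ball` factors. The function names and the default gravity of 9.81 are assumptions made for illustration only; the task itself reads gravity from `physics.model.opt.gravity` and does not define these helpers.

```python
import math


def yaw_quaternion(azimuth):
  """Unit quaternion (w, x, y, z) for a rotation of `azimuth` radians about +z.

  Hypothetical helper mirroring the `orientation` expression in
  `Stand.initialize_episode`.
  """
  return (math.cos(azimuth / 2.0), 0.0, 0.0, math.sin(azimuth / 2.0))


def launch_speed_for_apex(height, gravity=9.81):
  """Vertical launch speed whose ballistic apex is `height` above the launch point.

  From m*v^2/2 = m*g*h, so v = sqrt(2*g*h). The default gravity is an
  assumption; the task uses the magnitude stored in the MuJoCo model.
  """
  return math.sqrt(2.0 * gravity * height)


def soften(factor, weight):
  """Rescales a factor in [0, 1] to [1/(weight + 1), 1].

  With weight 4, 6 and 1 this matches the expressions used for `forward`,
  `reach_ball` and `fetch_ball`. Because the result is never 0, a softened
  factor on its own cannot zero out a product of reward factors.
  """
  return (weight * factor + 1.0) / (weight + 1.0)
```

For instance, `soften(0.0, 4)` gives the lower bound that `forward` takes when the dog is not moving at all, so a standing dog in a `Move` task still receives a nonzero reward from the product in `get_reward`.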
diff --git a/dm_control/suite/dog.xml b/dm_control/suite/dog.xml
new file mode 100644
index 00000000..6501f3f6
--- /dev/null
+++ b/dm_control/suite/dog.xml
@@ -0,0 +1,999 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/dog_assets/BONEC_1.stl b/dm_control/suite/dog_assets/BONEC_1.stl
new file mode 100644
index 00000000..22823c75
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_2.stl b/dm_control/suite/dog_assets/BONEC_2.stl
new file mode 100644
index 00000000..3b6304b8
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_3.stl b/dm_control/suite/dog_assets/BONEC_3.stl
new file mode 100644
index 00000000..5e0896d6
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_4.stl b/dm_control/suite/dog_assets/BONEC_4.stl
new file mode 100644
index 00000000..45e8484f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_5.stl b/dm_control/suite/dog_assets/BONEC_5.stl
new file mode 100644
index 00000000..04647f1c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_5.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_6.stl b/dm_control/suite/dog_assets/BONEC_6.stl
new file mode 100644
index 00000000..27e1cf36
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_6.stl differ
diff --git a/dm_control/suite/dog_assets/BONEC_7.stl b/dm_control/suite/dog_assets/BONEC_7.stl
new file mode 100644
index 00000000..f5fff989
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEC_7.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_1.stl b/dm_control/suite/dog_assets/BONECa_1.stl
new file mode 100644
index 00000000..fb954ace
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_10.stl b/dm_control/suite/dog_assets/BONECa_10.stl
new file mode 100644
index 00000000..754eef94
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_10.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_11.stl b/dm_control/suite/dog_assets/BONECa_11.stl
new file mode 100644
index 00000000..b7248609
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_11.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_12.stl b/dm_control/suite/dog_assets/BONECa_12.stl
new file mode 100644
index 00000000..bfb4dceb
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_12.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_13.stl b/dm_control/suite/dog_assets/BONECa_13.stl
new file mode 100644
index 00000000..d02310d2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_13.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_14.stl b/dm_control/suite/dog_assets/BONECa_14.stl
new file mode 100644
index 00000000..ca2ad232
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_14.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_15.stl b/dm_control/suite/dog_assets/BONECa_15.stl
new file mode 100644
index 00000000..98d15a84
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_15.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_16.stl b/dm_control/suite/dog_assets/BONECa_16.stl
new file mode 100644
index 00000000..353bb16b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_16.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_17.stl b/dm_control/suite/dog_assets/BONECa_17.stl
new file mode 100644
index 00000000..2444511a
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_17.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_18.stl b/dm_control/suite/dog_assets/BONECa_18.stl
new file mode 100644
index 00000000..5719d8fa
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_18.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_19.stl b/dm_control/suite/dog_assets/BONECa_19.stl
new file mode 100644
index 00000000..23f5ffb0
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_19.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_2.stl b/dm_control/suite/dog_assets/BONECa_2.stl
new file mode 100644
index 00000000..80cd720c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_20.stl b/dm_control/suite/dog_assets/BONECa_20.stl
new file mode 100644
index 00000000..c9e5ff58
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_20.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_21.stl b/dm_control/suite/dog_assets/BONECa_21.stl
new file mode 100644
index 00000000..b0aa6016
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_21.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_3.stl b/dm_control/suite/dog_assets/BONECa_3.stl
new file mode 100644
index 00000000..5ade6aa2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_4.stl b/dm_control/suite/dog_assets/BONECa_4.stl
new file mode 100644
index 00000000..1c9827f9
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_5.stl b/dm_control/suite/dog_assets/BONECa_5.stl
new file mode 100644
index 00000000..0c665086
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_5.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_6.stl b/dm_control/suite/dog_assets/BONECa_6.stl
new file mode 100644
index 00000000..5d45d16c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_6.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_7.stl b/dm_control/suite/dog_assets/BONECa_7.stl
new file mode 100644
index 00000000..cb12d332
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_7.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_8.stl b/dm_control/suite/dog_assets/BONECa_8.stl
new file mode 100644
index 00000000..d56a43bc
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_8.stl differ
diff --git a/dm_control/suite/dog_assets/BONECa_9.stl b/dm_control/suite/dog_assets/BONECa_9.stl
new file mode 100644
index 00000000..70bb439a
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECa_9.stl differ
diff --git a/dm_control/suite/dog_assets/BONECalcaneal_tuber_L.stl b/dm_control/suite/dog_assets/BONECalcaneal_tuber_L.stl
new file mode 100644
index 00000000..4658f004
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECalcaneal_tuber_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECalcaneal_tuber_R.stl b/dm_control/suite/dog_assets/BONECalcaneal_tuber_R.stl
new file mode 100644
index 00000000..07dfa5be
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECalcaneal_tuber_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_III_L.stl b/dm_control/suite/dog_assets/BONECarpal_III_L.stl
new file mode 100644
index 00000000..1072e9af
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_III_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_III_R.stl b/dm_control/suite/dog_assets/BONECarpal_III_R.stl
new file mode 100644
index 00000000..5a094f89
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_III_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_II_L.stl b/dm_control/suite/dog_assets/BONECarpal_II_L.stl
new file mode 100644
index 00000000..8d999886
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_II_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_II_R.stl b/dm_control/suite/dog_assets/BONECarpal_II_R.stl
new file mode 100644
index 00000000..b40c7e03
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_II_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_IV_L.stl b/dm_control/suite/dog_assets/BONECarpal_IV_L.stl
new file mode 100644
index 00000000..f6214950
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_IV_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_IV_R.stl b/dm_control/suite/dog_assets/BONECarpal_IV_R.stl
new file mode 100644
index 00000000..5bdc0403
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_IV_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_I_L.stl b/dm_control/suite/dog_assets/BONECarpal_I_L.stl
new file mode 100644
index 00000000..831cb992
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_I_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_I_R.stl b/dm_control/suite/dog_assets/BONECarpal_I_R.stl
new file mode 100644
index 00000000..be558771
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_I_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_L.stl b/dm_control/suite/dog_assets/BONECarpal_L.stl
new file mode 100644
index 00000000..f1b38906
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_R.stl b/dm_control/suite/dog_assets/BONECarpal_R.stl
new file mode 100644
index 00000000..b520888f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_Sesamoid_L.stl b/dm_control/suite/dog_assets/BONECarpal_Sesamoid_L.stl
new file mode 100644
index 00000000..6141aea4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_Sesamoid_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_Sesamoid_R.stl b/dm_control/suite/dog_assets/BONECarpal_Sesamoid_R.stl
new file mode 100644
index 00000000..521c867f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_Sesamoid_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_accessory_L.stl b/dm_control/suite/dog_assets/BONECarpal_accessory_L.stl
new file mode 100644
index 00000000..75440f1d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_accessory_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_accessory_R.stl b/dm_control/suite/dog_assets/BONECarpal_accessory_R.stl
new file mode 100644
index 00000000..1eb18729
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_accessory_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_ulnar_L.stl b/dm_control/suite/dog_assets/BONECarpal_ulnar_L.stl
new file mode 100644
index 00000000..2d609013
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_ulnar_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONECarpal_ulnar_R.stl b/dm_control/suite/dog_assets/BONECarpal_ulnar_R.stl
new file mode 100644
index 00000000..acc79d01
Binary files /dev/null and b/dm_control/suite/dog_assets/BONECarpal_ulnar_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_L.stl b/dm_control/suite/dog_assets/BONEFemoris_L.stl
new file mode 100644
index 00000000..bfe3e05c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_R.stl b/dm_control/suite/dog_assets/BONEFemoris_R.stl
new file mode 100644
index 00000000..aa35fc11
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_1.stl b/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_1.stl
new file mode 100644
index 00000000..e08dbd67
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_2.stl b/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_2.stl
new file mode 100644
index 00000000..885c1b1b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_fabellae_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_1.stl b/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_1.stl
new file mode 100644
index 00000000..5012fab0
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_2.stl b/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_2.stl
new file mode 100644
index 00000000..9517967e
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFemoris_fabellae_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFibula_L.stl b/dm_control/suite/dog_assets/BONEFibula_L.stl
new file mode 100644
index 00000000..069bc13f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFibula_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEFibula_R.stl b/dm_control/suite/dog_assets/BONEFibula_R.stl
new file mode 100644
index 00000000..01c667c9
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEFibula_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEJaw.stl b/dm_control/suite/dog_assets/BONEJaw.stl
new file mode 100644
index 00000000..be5b325f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEJaw.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_1.stl b/dm_control/suite/dog_assets/BONEL_1.stl
new file mode 100644
index 00000000..a6f2e298
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_2.stl b/dm_control/suite/dog_assets/BONEL_2.stl
new file mode 100644
index 00000000..7cd3a5c2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_3.stl b/dm_control/suite/dog_assets/BONEL_3.stl
new file mode 100644
index 00000000..d79992e7
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_4.stl b/dm_control/suite/dog_assets/BONEL_4.stl
new file mode 100644
index 00000000..751eaacb
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_5.stl b/dm_control/suite/dog_assets/BONEL_5.stl
new file mode 100644
index 00000000..d340f1c2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_5.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_6.stl b/dm_control/suite/dog_assets/BONEL_6.stl
new file mode 100644
index 00000000..5c7fc857
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_6.stl differ
diff --git a/dm_control/suite/dog_assets/BONEL_7.stl b/dm_control/suite/dog_assets/BONEL_7.stl
new file mode 100644
index 00000000..5a4ce4e4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEL_7.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_1.stl b/dm_control/suite/dog_assets/BONELingual_bone_1.stl
new file mode 100644
index 00000000..1e1ecf82
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_2.stl b/dm_control/suite/dog_assets/BONELingual_bone_2.stl
new file mode 100644
index 00000000..8f537f82
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_3.stl b/dm_control/suite/dog_assets/BONELingual_bone_3.stl
new file mode 100644
index 00000000..2cbea4d6
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_4.stl b/dm_control/suite/dog_assets/BONELingual_bone_4.stl
new file mode 100644
index 00000000..9eaecafd
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_5.stl b/dm_control/suite/dog_assets/BONELingual_bone_5.stl
new file mode 100644
index 00000000..f1e9ba4d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_5.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_6.stl b/dm_control/suite/dog_assets/BONELingual_bone_6.stl
new file mode 100644
index 00000000..faa2f12f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_6.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_7.stl b/dm_control/suite/dog_assets/BONELingual_bone_7.stl
new file mode 100644
index 00000000..24edea25
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_7.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_8.stl b/dm_control/suite/dog_assets/BONELingual_bone_8.stl
new file mode 100644
index 00000000..0e9d5924
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_8.stl differ
diff --git a/dm_control/suite/dog_assets/BONELingual_bone_9.stl b/dm_control/suite/dog_assets/BONELingual_bone_9.stl
new file mode 100644
index 00000000..a3a19ca5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONELingual_bone_9.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMergedSkull.stl b/dm_control/suite/dog_assets/BONEMergedSkull.stl
new file mode 100644
index 00000000..8c39c897
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMergedSkull.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_L_1.stl b/dm_control/suite/dog_assets/BONEMetatarsi_L_1.stl
new file mode 100644
index 00000000..8a7e46b1
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_L_2.stl b/dm_control/suite/dog_assets/BONEMetatarsi_L_2.stl
new file mode 100644
index 00000000..fd14f699
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_L_3.stl b/dm_control/suite/dog_assets/BONEMetatarsi_L_3.stl
new file mode 100644
index 00000000..c9f9f3a2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_L_4.stl b/dm_control/suite/dog_assets/BONEMetatarsi_L_4.stl
new file mode 100644
index 00000000..843bc6b4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_R_1.stl b/dm_control/suite/dog_assets/BONEMetatarsi_R_1.stl
new file mode 100644
index 00000000..50f56732
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_R_2.stl b/dm_control/suite/dog_assets/BONEMetatarsi_R_2.stl
new file mode 100644
index 00000000..4bfad264
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_R_3.stl b/dm_control/suite/dog_assets/BONEMetatarsi_R_3.stl
new file mode 100644
index 00000000..331e0705
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEMetatarsi_R_4.stl b/dm_control/suite/dog_assets/BONEMetatarsi_R_4.stl
new file mode 100644
index 00000000..6dae75c9
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEMetatarsi_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_III_L.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_III_L.stl
new file mode 100644
index 00000000..4cf6005d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_III_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_III_R.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_III_R.stl
new file mode 100644
index 00000000..f791ea6e
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_III_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_II_L.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_II_L.stl
new file mode 100644
index 00000000..cd6095f1
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_II_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_II_R.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_II_R.stl
new file mode 100644
index 00000000..856bd0a5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_II_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_L.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_L.stl
new file mode 100644
index 00000000..70c3301c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_R.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_R.stl
new file mode 100644
index 00000000..d5f7ed2d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_IV_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_I_L.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_I_L.stl
new file mode 100644
index 00000000..5807e7c1
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_I_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_I_R.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_I_R.stl
new file mode 100644
index 00000000..75982feb
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_I_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.001.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.001.stl
new file mode 100644
index 00000000..236d65d4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.stl
new file mode 100644
index 00000000..af4af0a7
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.001.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.001.stl
new file mode 100644
index 00000000..38f32794
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.stl b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.stl
new file mode 100644
index 00000000..68887942
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEOs_metacarpale_V_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPatella_L.stl b/dm_control/suite/dog_assets/BONEPatella_L.stl
new file mode 100644
index 00000000..c0741a8b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPatella_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPatella_R.stl b/dm_control/suite/dog_assets/BONEPatella_R.stl
new file mode 100644
index 00000000..48283791
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPatella_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPelvis.stl b/dm_control/suite/dog_assets/BONEPelvis.stl
new file mode 100644
index 00000000..3424ee9f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPelvis.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_1.stl
new file mode 100644
index 00000000..390656d4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_2.stl
new file mode 100644
index 00000000..ede8163e
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_3.stl
new file mode 100644
index 00000000..d955b7bb
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_4.stl
new file mode 100644
index 00000000..e4de6d60
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_1.stl
new file mode 100644
index 00000000..85e75f91
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_2.stl
new file mode 100644
index 00000000..e85f873e
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_3.stl
new file mode 100644
index 00000000..0b30d0a3
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_4.stl
new file mode 100644
index 00000000..e9173173
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_2_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_1.stl
new file mode 100644
index 00000000..1b73fdb7
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_2.stl
new file mode 100644
index 00000000..6d7c013b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_3.stl
new file mode 100644
index 00000000..1d76867f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_4.stl
new file mode 100644
index 00000000..0c3ac662
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_1.stl
new file mode 100644
index 00000000..807e9245
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_2.stl
new file mode 100644
index 00000000..43bb40e9
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_3.stl
new file mode 100644
index 00000000..02f8c2c5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_4.stl
new file mode 100644
index 00000000..2dc164c8
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_3_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_L_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_L_1.stl
new file mode 100644
index 00000000..4cb14f8d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_L_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_L_2.stl
new file mode 100644
index 00000000..0edc86f4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_L_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_L_3.stl
new file mode 100644
index 00000000..6f4d60bd
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_L_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_L_4.stl
new file mode 100644
index 00000000..6aa760f0
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_R_1.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_R_1.stl
new file mode 100644
index 00000000..7e47eea0
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_R_2.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_R_2.stl
new file mode 100644
index 00000000..92a91477
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_R_3.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_R_3.stl
new file mode 100644
index 00000000..0fead297
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanges_B_R_4.stl b/dm_control/suite/dog_assets/BONEPhalanges_B_R_4.stl
new file mode 100644
index 00000000..97ee5394
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanges_B_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_L.stl
new file mode 100644
index 00000000..f718c8a9
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_R.stl
new file mode 100644
index 00000000..2213a1a5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_III_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_L.stl
new file mode 100644
index 00000000..fef630bc
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_R.stl
new file mode 100644
index 00000000..2a0572b3
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_II_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_L.stl
new file mode 100644
index 00000000..3e9ce552
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_R.stl
new file mode 100644
index 00000000..f486e395
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_IV_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_L.stl
new file mode 100644
index 00000000..c3d08ea8
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_R.stl
new file mode 100644
index 00000000..2d6f0f09
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_I_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_L.stl
new file mode 100644
index 00000000..0b82c43b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_R.stl
new file mode 100644
index 00000000..5d4c4848
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_distalis_digiti_V_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_L.stl
new file mode 100644
index 00000000..18de4f0f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_R.stl
new file mode 100644
index 00000000..a3d24d7c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_III_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_L.stl
new file mode 100644
index 00000000..76a005df
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_R.stl
new file mode 100644
index 00000000..d9336b07
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_II_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_L.stl
new file mode 100644
index 00000000..3b92482b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_R.stl
new file mode 100644
index 00000000..cc55312b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_IV_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_L.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_L.stl
new file mode 100644
index 00000000..f96ef65a
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_R.stl b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_R.stl
new file mode 100644
index 00000000..e8a47cc3
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanx_media_digiti_V_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_L_.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_L_.stl
new file mode 100644
index 00000000..8d0e99f4
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_L_.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_R.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_R.stl
new file mode 100644
index 00000000..91ba0ca2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_III_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_L.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_L.stl
new file mode 100644
index 00000000..08e07fa0
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_R.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_R.stl
new file mode 100644
index 00000000..a4f10f6e
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_II_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_L_.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_L_.stl
new file mode 100644
index 00000000..745c209d
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_L_.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_R.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_R.stl
new file mode 100644
index 00000000..6647772c
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_IV_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_L.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_L.stl
new file mode 100644
index 00000000..713ec124
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_R.stl b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_R.stl
new file mode 100644
index 00000000..dca18c00
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEPhalanxpmxinutlis_digiti_V_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONERadius_L.stl b/dm_control/suite/dog_assets/BONERadius_L.stl
new file mode 100644
index 00000000..a469a532
Binary files /dev/null and b/dm_control/suite/dog_assets/BONERadius_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONERadius_R.stl b/dm_control/suite/dog_assets/BONERadius_R.stl
new file mode 100644
index 00000000..87a95b95
Binary files /dev/null and b/dm_control/suite/dog_assets/BONERadius_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONERibcage.stl b/dm_control/suite/dog_assets/BONERibcage.stl
new file mode 100644
index 00000000..0caf3843
Binary files /dev/null and b/dm_control/suite/dog_assets/BONERibcage.stl differ
diff --git a/dm_control/suite/dog_assets/BONESacrum.stl b/dm_control/suite/dog_assets/BONESacrum.stl
new file mode 100644
index 00000000..050a752b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONESacrum.stl differ
diff --git a/dm_control/suite/dog_assets/BONEScapula_L.stl b/dm_control/suite/dog_assets/BONEScapula_L.stl
new file mode 100644
index 00000000..0d649208
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEScapula_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEScapula_R.stl b/dm_control/suite/dog_assets/BONEScapula_R.stl
new file mode 100644
index 00000000..f9f38863
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEScapula_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_L_I.stl b/dm_control/suite/dog_assets/BONETarsus_L_I.stl
new file mode 100644
index 00000000..74f11049
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_L_I.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_L_II.stl b/dm_control/suite/dog_assets/BONETarsus_L_II.stl
new file mode 100644
index 00000000..dedfc8c2
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_L_II.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_L_III.stl b/dm_control/suite/dog_assets/BONETarsus_L_III.stl
new file mode 100644
index 00000000..4cd886b5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_L_III.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_L_IV.stl b/dm_control/suite/dog_assets/BONETarsus_L_IV.stl
new file mode 100644
index 00000000..b60c9c04
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_L_IV.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_R_I.stl b/dm_control/suite/dog_assets/BONETarsus_R_I.stl
new file mode 100644
index 00000000..949bfe2f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_R_I.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_R_II.stl b/dm_control/suite/dog_assets/BONETarsus_R_II.stl
new file mode 100644
index 00000000..cf43e754
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_R_II.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_R_III.stl b/dm_control/suite/dog_assets/BONETarsus_R_III.stl
new file mode 100644
index 00000000..65b00e25
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_R_III.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_R_IV.stl b/dm_control/suite/dog_assets/BONETarsus_R_IV.stl
new file mode 100644
index 00000000..c4ce31cc
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_R_IV.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_central_L.stl b/dm_control/suite/dog_assets/BONETarsus_central_L.stl
new file mode 100644
index 00000000..764bc5e5
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_central_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONETarsus_central_R.stl b/dm_control/suite/dog_assets/BONETarsus_central_R.stl
new file mode 100644
index 00000000..45af1a3a
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETarsus_central_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONETibia_L.stl b/dm_control/suite/dog_assets/BONETibia_L.stl
new file mode 100644
index 00000000..2e06ff8b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETibia_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONETibia_R.stl b/dm_control/suite/dog_assets/BONETibia_R.stl
new file mode 100644
index 00000000..1682838f
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETibia_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONETibial_tarsal_L.stl b/dm_control/suite/dog_assets/BONETibial_tarsal_L.stl
new file mode 100644
index 00000000..05457d97
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETibial_tarsal_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONETibial_tarsal_R.stl b/dm_control/suite/dog_assets/BONETibial_tarsal_R.stl
new file mode 100644
index 00000000..93747812
Binary files /dev/null and b/dm_control/suite/dog_assets/BONETibial_tarsal_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEUlna_L.stl b/dm_control/suite/dog_assets/BONEUlna_L.stl
new file mode 100644
index 00000000..797b47e6
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEUlna_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEUlna_R.stl b/dm_control/suite/dog_assets/BONEUlna_R.stl
new file mode 100644
index 00000000..5ca584b3
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEUlna_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEXiphoid_cartilage.stl b/dm_control/suite/dog_assets/BONEXiphoid_cartilage.stl
new file mode 100644
index 00000000..46ccf2cd
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEXiphoid_cartilage.stl differ
diff --git a/dm_control/suite/dog_assets/BONEeye_L.stl b/dm_control/suite/dog_assets/BONEeye_L.stl
new file mode 100644
index 00000000..57d1dcb7
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEeye_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEeye_R.stl b/dm_control/suite/dog_assets/BONEeye_R.stl
new file mode 100644
index 00000000..856c2012
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEeye_R.stl differ
diff --git a/dm_control/suite/dog_assets/BONEhumerus_L.stl b/dm_control/suite/dog_assets/BONEhumerus_L.stl
new file mode 100644
index 00000000..512e0ae8
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEhumerus_L.stl differ
diff --git a/dm_control/suite/dog_assets/BONEhumerus_R.stl b/dm_control/suite/dog_assets/BONEhumerus_R.stl
new file mode 100644
index 00000000..ae14546b
Binary files /dev/null and b/dm_control/suite/dog_assets/BONEhumerus_R.stl differ
diff --git a/dm_control/suite/dog_assets/SKINbody.stl b/dm_control/suite/dog_assets/SKINbody.stl
new file mode 100644
index 00000000..42a9018a
Binary files /dev/null and b/dm_control/suite/dog_assets/SKINbody.stl differ
diff --git a/dm_control/suite/dog_assets/dog_skin.msh b/dm_control/suite/dog_assets/dog_skin.msh
new file mode 100644
index 00000000..a2dd67bd
Binary files /dev/null and b/dm_control/suite/dog_assets/dog_skin.msh differ
diff --git a/dm_control/suite/dog_assets/dog_skin.skn b/dm_control/suite/dog_assets/dog_skin.skn
new file mode 100644
index 00000000..e56632a5
Binary files /dev/null and b/dm_control/suite/dog_assets/dog_skin.skn differ
diff --git a/dm_control/suite/dog_assets/extras/BONECanine_Bottom_L.stl b/dm_control/suite/dog_assets/extras/BONECanine_Bottom_L.stl
new file mode 100644
index 00000000..65831142
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONECanine_Bottom_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONECanine_Bottom_R.stl b/dm_control/suite/dog_assets/extras/BONECanine_Bottom_R.stl
new file mode 100644
index 00000000..83b01a9e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONECanine_Bottom_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONECanine_Top_L.stl b/dm_control/suite/dog_assets/extras/BONECanine_Top_L.stl
new file mode 100644
index 00000000..f5b3cf78
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONECanine_Top_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONECanine_Top_R.stl b/dm_control/suite/dog_assets/extras/BONECanine_Top_R.stl
new file mode 100644
index 00000000..ef6cb8e5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONECanine_Top_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEEthmoid.stl b/dm_control/suite/dog_assets/extras/BONEEthmoid.stl
new file mode 100644
index 00000000..d1078e60
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEEthmoid.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_1.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_1.stl
new file mode 100644
index 00000000..58989bcb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_2.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_2.stl
new file mode 100644
index 00000000..ae7d110a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_3.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_3.stl
new file mode 100644
index 00000000..16890e36
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_1.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_1.stl
new file mode 100644
index 00000000..142cd850
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_2.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_2.stl
new file mode 100644
index 00000000..ed694e6c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_3.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_3.stl
new file mode 100644
index 00000000..59f7fb4f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Bottom_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_1.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_1.stl
new file mode 100644
index 00000000..25fd7c14
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_2.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_2.stl
new file mode 100644
index 00000000..5b009a68
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_3.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_3.stl
new file mode 100644
index 00000000..d4add005
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_1.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_1.stl
new file mode 100644
index 00000000..92773ba6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_2.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_2.stl
new file mode 100644
index 00000000..af17817a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_3.stl b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_3.stl
new file mode 100644
index 00000000..6eed6cac
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEIncisors_Top_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMandible.stl b/dm_control/suite/dog_assets/extras/BONEMandible.stl
new file mode 100644
index 00000000..d31458c8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMandible.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_3L.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_3L.stl
new file mode 100644
index 00000000..fff491ec
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_3L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_1.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_1.stl
new file mode 100644
index 00000000..5758e0f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_2.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_2.stl
new file mode 100644
index 00000000..fee8a645
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_1.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_1.stl
new file mode 100644
index 00000000..9a77a9aa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_2.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_2.stl
new file mode 100644
index 00000000..0ecedbe5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_3.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_3.stl
new file mode 100644
index 00000000..5d6736bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Bottom_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Top_1L.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Top_1L.stl
new file mode 100644
index 00000000..996ce504
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Top_1L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Top_2L.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Top_2L.stl
new file mode 100644
index 00000000..6ed576ed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Top_2L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_1.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_1.stl
new file mode 100644
index 00000000..ecad7bdb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_2.stl b/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_2.stl
new file mode 100644
index 00000000..9744b733
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEMolars_Top_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_1.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_1.stl
new file mode 100644
index 00000000..4e193107
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_2.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_2.stl
new file mode 100644
index 00000000..a3283e25
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_3.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_3.stl
new file mode 100644
index 00000000..79ac6a52
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_4.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_4.stl
new file mode 100644
index 00000000..9c01465f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_1.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_1.stl
new file mode 100644
index 00000000..82a51ae6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_2.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_2.stl
new file mode 100644
index 00000000..63bcad11
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_3.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_3.stl
new file mode 100644
index 00000000..205df18b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_4.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_4.stl
new file mode 100644
index 00000000..63afbd29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Bottom_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_1.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_1.stl
new file mode 100644
index 00000000..9f0a0056
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_2.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_2.stl
new file mode 100644
index 00000000..66550256
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_3.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_3.stl
new file mode 100644
index 00000000..e995c7b2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_4.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_4.stl
new file mode 100644
index 00000000..df45c4f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_1.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_1.stl
new file mode 100644
index 00000000..83cdf4da
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_2.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_2.stl
new file mode 100644
index 00000000..01cfcbf8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_3.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_3.stl
new file mode 100644
index 00000000..a67119e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_7.stl b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_7.stl
new file mode 100644
index 00000000..9b6c9dc7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEPremolars_Top_R_7.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_1.stl b/dm_control/suite/dog_assets/extras/BONERib_L_1.stl
new file mode 100644
index 00000000..d441ff17
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_10.stl b/dm_control/suite/dog_assets/extras/BONERib_L_10.stl
new file mode 100644
index 00000000..89bded5c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_10.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_11.stl b/dm_control/suite/dog_assets/extras/BONERib_L_11.stl
new file mode 100644
index 00000000..9e3f99a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_11.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_12.stl b/dm_control/suite/dog_assets/extras/BONERib_L_12.stl
new file mode 100644
index 00000000..ea878659
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_12.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_13.stl b/dm_control/suite/dog_assets/extras/BONERib_L_13.stl
new file mode 100644
index 00000000..6de07ab8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_13.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_2.stl b/dm_control/suite/dog_assets/extras/BONERib_L_2.stl
new file mode 100644
index 00000000..6aae7af7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_3.stl b/dm_control/suite/dog_assets/extras/BONERib_L_3.stl
new file mode 100644
index 00000000..4d0e73e4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_4.stl b/dm_control/suite/dog_assets/extras/BONERib_L_4.stl
new file mode 100644
index 00000000..a46e93f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_5.stl b/dm_control/suite/dog_assets/extras/BONERib_L_5.stl
new file mode 100644
index 00000000..4cb392ed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_5.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_6.stl b/dm_control/suite/dog_assets/extras/BONERib_L_6.stl
new file mode 100644
index 00000000..99c3a3b5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_6.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_7.stl b/dm_control/suite/dog_assets/extras/BONERib_L_7.stl
new file mode 100644
index 00000000..23d321dd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_7.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_8.stl b/dm_control/suite/dog_assets/extras/BONERib_L_8.stl
new file mode 100644
index 00000000..ae65550e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_8.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_L_9.stl b/dm_control/suite/dog_assets/extras/BONERib_L_9.stl
new file mode 100644
index 00000000..5bf740e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_L_9.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_1.stl b/dm_control/suite/dog_assets/extras/BONERib_R_1.stl
new file mode 100644
index 00000000..c953ba1b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_10.stl b/dm_control/suite/dog_assets/extras/BONERib_R_10.stl
new file mode 100644
index 00000000..996662ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_10.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_11.stl b/dm_control/suite/dog_assets/extras/BONERib_R_11.stl
new file mode 100644
index 00000000..bdf2eb23
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_11.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_12.stl b/dm_control/suite/dog_assets/extras/BONERib_R_12.stl
new file mode 100644
index 00000000..2a14aa13
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_12.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_13.stl b/dm_control/suite/dog_assets/extras/BONERib_R_13.stl
new file mode 100644
index 00000000..5602ab9a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_13.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_2.stl b/dm_control/suite/dog_assets/extras/BONERib_R_2.stl
new file mode 100644
index 00000000..2f8ea131
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_3.stl b/dm_control/suite/dog_assets/extras/BONERib_R_3.stl
new file mode 100644
index 00000000..89b38aca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_4.stl b/dm_control/suite/dog_assets/extras/BONERib_R_4.stl
new file mode 100644
index 00000000..3b3a71e1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_5.stl b/dm_control/suite/dog_assets/extras/BONERib_R_5.stl
new file mode 100644
index 00000000..47864887
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_5.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_6.stl b/dm_control/suite/dog_assets/extras/BONERib_R_6.stl
new file mode 100644
index 00000000..50157091
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_6.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_7.stl b/dm_control/suite/dog_assets/extras/BONERib_R_7.stl
new file mode 100644
index 00000000..7b7a43d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_7.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_8.stl b/dm_control/suite/dog_assets/extras/BONERib_R_8.stl
new file mode 100644
index 00000000..b7a79ea7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_8.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONERib_R_9.stl b/dm_control/suite/dog_assets/extras/BONERib_R_9.stl
new file mode 100644
index 00000000..5a7718fa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONERib_R_9.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESkull.stl b/dm_control/suite/dog_assets/extras/BONESkull.stl
new file mode 100644
index 00000000..71db76d8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESkull.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_1.stl b/dm_control/suite/dog_assets/extras/BONESternum_1.stl
new file mode 100644
index 00000000..3c4e14be
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_2.stl b/dm_control/suite/dog_assets/extras/BONESternum_2.stl
new file mode 100644
index 00000000..5a92b667
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_3.stl b/dm_control/suite/dog_assets/extras/BONESternum_3.stl
new file mode 100644
index 00000000..db0d392d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_4.stl b/dm_control/suite/dog_assets/extras/BONESternum_4.stl
new file mode 100644
index 00000000..4dc3684b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_5.stl b/dm_control/suite/dog_assets/extras/BONESternum_5.stl
new file mode 100644
index 00000000..db42eb14
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_5.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_6.stl b/dm_control/suite/dog_assets/extras/BONESternum_6.stl
new file mode 100644
index 00000000..575df46d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_6.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_7.stl b/dm_control/suite/dog_assets/extras/BONESternum_7.stl
new file mode 100644
index 00000000..0eb16268
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_7.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONESternum_8.stl b/dm_control/suite/dog_assets/extras/BONESternum_8.stl
new file mode 100644
index 00000000..f0a47fc4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONESternum_8.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_1.stl b/dm_control/suite/dog_assets/extras/BONET_1.stl
new file mode 100644
index 00000000..e76a4d63
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_10.stl b/dm_control/suite/dog_assets/extras/BONET_10.stl
new file mode 100644
index 00000000..7030f39a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_10.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_11.stl b/dm_control/suite/dog_assets/extras/BONET_11.stl
new file mode 100644
index 00000000..4702a014
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_11.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_12.stl b/dm_control/suite/dog_assets/extras/BONET_12.stl
new file mode 100644
index 00000000..df8d8ea3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_12.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_13.stl b/dm_control/suite/dog_assets/extras/BONET_13.stl
new file mode 100644
index 00000000..6549b9e7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_13.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_2.stl b/dm_control/suite/dog_assets/extras/BONET_2.stl
new file mode 100644
index 00000000..6e6d7a32
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_3.stl b/dm_control/suite/dog_assets/extras/BONET_3.stl
new file mode 100644
index 00000000..10a57467
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_4.stl b/dm_control/suite/dog_assets/extras/BONET_4.stl
new file mode 100644
index 00000000..c2fb483f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_4.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_5.stl b/dm_control/suite/dog_assets/extras/BONET_5.stl
new file mode 100644
index 00000000..0b7c9ff0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_5.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_6.stl b/dm_control/suite/dog_assets/extras/BONET_6.stl
new file mode 100644
index 00000000..9fc9b57c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_6.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_7.stl b/dm_control/suite/dog_assets/extras/BONET_7.stl
new file mode 100644
index 00000000..aeef3c3d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_7.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_8.stl b/dm_control/suite/dog_assets/extras/BONET_8.stl
new file mode 100644
index 00000000..2e2a4a2d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_8.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONET_9.stl b/dm_control/suite/dog_assets/extras/BONET_9.stl
new file mode 100644
index 00000000..1507de87
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONET_9.stl differ
diff --git a/dm_control/suite/dog_assets/extras/BONEVomer.stl b/dm_control/suite/dog_assets/extras/BONEVomer.stl
new file mode 100644
index 00000000..0d355dc8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/BONEVomer.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCEye_L.stl b/dm_control/suite/dog_assets/extras/MUSCEye_L.stl
new file mode 100644
index 00000000..78bb6dfb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCEye_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCEye_R.stl b/dm_control/suite/dog_assets/extras/MUSCEye_R.stl
new file mode 100644
index 00000000..3edb1854
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCEye_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCLigament.stl b/dm_control/suite/dog_assets/extras/MUSCLigament.stl
new file mode 100644
index 00000000..3344dcc8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCLigament.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_Fascia_thoracolumbar_.stl b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_Fascia_thoracolumbar_.stl
new file mode 100644
index 00000000..ee980db0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_Fascia_thoracolumbar_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_L.stl b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_L.stl
new file mode 100644
index 00000000..c15b9fbf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_R.stl b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_R.stl
new file mode 100644
index 00000000..c3d3a318
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCMm.dorsi_scaleni_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..1a57f115
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..d332c286
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCMm.thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCNose.stl b/dm_control/suite/dog_assets/extras/MUSCNose.stl
new file mode 100644
index 00000000..29188e5f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCNose.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCSkutulum_L.stl b/dm_control/suite/dog_assets/extras/MUSCSkutulum_L.stl
new file mode 100644
index 00000000..fcb34be1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCSkutulum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCSkutulum_R.stl b/dm_control/suite/dog_assets/extras/MUSCSkutulum_R.stl
new file mode 100644
index 00000000..7f7aa4d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCSkutulum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCabdominis_Aponeurosis.stl b/dm_control/suite/dog_assets/extras/MUSCabdominis_Aponeurosis.stl
new file mode 100644
index 00000000..02ad56b7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCabdominis_Aponeurosis.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCdiaphragma.stl b/dm_control/suite/dog_assets/extras/MUSCdiaphragma.stl
new file mode 100644
index 00000000..8e3a929a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCdiaphragma.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_L.stl
new file mode 100644
index 00000000..800919a7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_R.stl
new file mode 100644
index 00000000..6b8c5475
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm._sternohyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_L.stl
new file mode 100644
index 00000000..6e80e28e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_R.stl
new file mode 100644
index 00000000..0b20aa29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.abductor_cruris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_L.stl
new file mode 100644
index 00000000..01bf7afb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_R.stl
new file mode 100644
index 00000000..327a32db
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.abductor_pollicis_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_L.stl
new file mode 100644
index 00000000..150823e4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_R.stl
new file mode 100644
index 00000000..db428c7f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.adductor_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_L.stl
new file mode 100644
index 00000000..896c55fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_R.stl
new file mode 100644
index 00000000..4c098d62
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.adductor_magnus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.anconeus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.anconeus_L.stl
new file mode 100644
index 00000000..b0ae5403
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.anconeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.anconeus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.anconeus_R.stl
new file mode 100644
index 00000000..17dd27e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.anconeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_L.stl
new file mode 100644
index 00000000..77a31f5a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_R.stl
new file mode 100644
index 00000000..a9bd941d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.biceps_brachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.brachialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.brachialis_L.stl
new file mode 100644
index 00000000..cf609dba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.brachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.brachialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.brachialis_R.stl
new file mode 100644
index 00000000..5f1c0fed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.brachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.brachiocephalicus.stl b/dm_control/suite/dog_assets/extras/MUSCm.brachiocephalicus.stl
new file mode 100644
index 00000000..9e076fb2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.brachiocephalicus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L(2).stl
new file mode 100644
index 00000000..6a3f2022
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L.stl
new file mode 100644
index 00000000..c61824d2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R(2).stl
new file mode 100644
index 00000000..6733c4b3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R.stl
new file mode 100644
index 00000000..80a69b54
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.buccinator_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.caninus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.caninus_L.stl
new file mode 100644
index 00000000..0b05d675
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.caninus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.caninus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.caninus_R.stl
new file mode 100644
index 00000000..7272aced
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.caninus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.capsularis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.capsularis_L.stl
new file mode 100644
index 00000000..fc6654cc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.capsularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.capsularis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.capsularis_R.stl
new file mode 100644
index 00000000..2c0cf4dd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.capsularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_L.stl
new file mode 100644
index 00000000..772bff11
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_R.stl
new file mode 100644
index 00000000..94032200
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.coccygeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_L.stl
new file mode 100644
index 00000000..f7c376ca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_R.stl
new file mode 100644
index 00000000..26c74dea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.coracobrachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_L.stl
new file mode 100644
index 00000000..336d91a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_R.stl
new file mode 100644
index 00000000..e68dd458
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_acromialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_L.stl
new file mode 100644
index 00000000..dfe321f3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_R.stl
new file mode 100644
index 00000000..e3802a62
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.deltoideus_scapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..8baa56e1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2(2).stl
new file mode 100644
index 00000000..7907934a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..8cf5e155
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..e92ddff6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2(2).stl
new file mode 100644
index 00000000..662f7fd3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..db5ec360
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.digastricus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.digastricus_L.stl
new file mode 100644
index 00000000..96d24638
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.digastricus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.digastricus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.digastricus_R.stl
new file mode 100644
index 00000000..d227936d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.digastricus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_radialis_R.stl
new file mode 100644
index 00000000..49e68a92
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_L.stl
new file mode 100644
index 00000000..cbc39078
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_R.stl
new file mode 100644
index 00000000..23cab4fa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_capri_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..36b2eab7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..0760593c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..d65c94f4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..961ab791
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_L.stl
new file mode 100644
index 00000000..0fc34988
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_R.stl
new file mode 100644
index 00000000..9a0a7c75
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..dc34c59c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..dc3633ca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_L.stl
new file mode 100644
index 00000000..b86e9663
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_R.stl
new file mode 100644
index 00000000..ce6391f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_L.stl
new file mode 100644
index 00000000..24854a30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_R.stl
new file mode 100644
index 00000000..71d1a280
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_carpi_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L(2).stl
new file mode 100644
index 00000000..0acf9c8d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..a26c6672
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R(2).stl
new file mode 100644
index 00000000..e66c3cdc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..f46e1c01
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.001.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.001.stl
new file mode 100644
index 00000000..286aff52
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.stl
new file mode 100644
index 00000000..bc9d39b3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_R.stl
new file mode 100644
index 00000000..873cab42
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.flexor_digitorum_superficialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_L.stl
new file mode 100644
index 00000000..7956b54c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_R.stl
new file mode 100644
index 00000000..ac6ffcce
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gastrocnemius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gemellus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gemellus_L.stl
new file mode 100644
index 00000000..c9d90e6b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gemellus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gemellus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gemellus_R.stl
new file mode 100644
index 00000000..dc0fdb56
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gemellus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_L.stl
new file mode 100644
index 00000000..1592995d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_R.stl
new file mode 100644
index 00000000..c6a9de42
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.geniohyoideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_L.stl
new file mode 100644
index 00000000..0882d056
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_R.stl
new file mode 100644
index 00000000..da075ad5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_medius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_L.stl
new file mode 100644
index 00000000..49435da3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_R.stl
new file mode 100644
index 00000000..ac94fb9f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_profundus_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_L.stl
new file mode 100644
index 00000000..ba1dabed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_R.stl
new file mode 100644
index 00000000..f2d093e8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gluteus_superficialis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gracilis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.gracilis_L.stl
new file mode 100644
index 00000000..028060a4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gracilis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.gracilis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.gracilis_R.stl
new file mode 100644
index 00000000..814cef63
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.gracilis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.iliacus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.iliacus_L.stl
new file mode 100644
index 00000000..d0e108d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.iliacus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.iliacus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.iliacus_R.stl
new file mode 100644
index 00000000..aaa65a0d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.iliacus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_L.stl
new file mode 100644
index 00000000..552fb14c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_R.stl
new file mode 100644
index 00000000..c288ac0b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.infraspinatus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.interscutularis.stl b/dm_control/suite/dog_assets/extras/MUSCm.interscutularis.stl
new file mode 100644
index 00000000..6a06ce06
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.interscutularis.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_L.stl
new file mode 100644
index 00000000..6ecdebce
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_R.stl
new file mode 100644
index 00000000..8546e4e8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.latissimus_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_L.stl
new file mode 100644
index 00000000..ac0f00fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_R.stl
new file mode 100644
index 00000000..47d88386
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_ani_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L.stl
new file mode 100644
index 00000000..9d594b02
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L1.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L1.stl
new file mode 100644
index 00000000..31779562
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..c6c484ef
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R.stl
new file mode 100644
index 00000000..ef315fdb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R1.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R1.stl
new file mode 100644
index 00000000..d4e28689
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..40d0ca87
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_L.stl
new file mode 100644
index 00000000..06f49a62
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_R.stl
new file mode 100644
index 00000000..8f2b65b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_L.stl
new file mode 100644
index 00000000..8c34cb02
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_R.stl
new file mode 100644
index 00000000..a56386ba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.levator_labii_sup_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_L.stl
new file mode 100644
index 00000000..5d5af9fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_R.stl
new file mode 100644
index 00000000..d8adf1d6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus__lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L(2).stl
new file mode 100644
index 00000000..a017c519
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L.stl
new file mode 100644
index 00000000..4ae5dd90
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R(2).stl
new file mode 100644
index 00000000..0e0f7eea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R.stl
new file mode 100644
index 00000000..68df52b8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..085516a6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..65736261
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_L.stl
new file mode 100644
index 00000000..8aa418a8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_R.stl
new file mode 100644
index 00000000..00112d12
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_L.stl
new file mode 100644
index 00000000..a766a6c1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_R.stl
new file mode 100644
index 00000000..3b1f018a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.masseter_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.masseter_L.stl
new file mode 100644
index 00000000..b05d46a6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.masseter_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.masseter_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.masseter_R.stl
new file mode 100644
index 00000000..58b4e968
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.masseter_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.mylohyoideus.stl b/dm_control/suite/dog_assets/extras/MUSCm.mylohyoideus.stl
new file mode 100644
index 00000000..97c4624f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.mylohyoideus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_L.stl
new file mode 100644
index 00000000..55889865
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_R.stl
new file mode 100644
index 00000000..e78be01e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_caudalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L(2).stl
new file mode 100644
index 00000000..47676424
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L.stl
new file mode 100644
index 00000000..72f556db
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R(2).stl
new file mode 100644
index 00000000..4e1e1c35
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R.stl
new file mode 100644
index 00000000..7b8a8b80
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_capitis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_L.stl
new file mode 100644
index 00000000..989e2794
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_R.stl
new file mode 100644
index 00000000..224324f0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_ext_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_L.stl
new file mode 100644
index 00000000..1b8facec
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_R.stl
new file mode 100644
index 00000000..3292f438
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obliquus_int_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_L.stl
new file mode 100644
index 00000000..f955d3f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_R.stl
new file mode 100644
index 00000000..59f88765
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_ext_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_L.stl
new file mode 100644
index 00000000..2385f8c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_R.stl
new file mode 100644
index 00000000..fd187373
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.obturatorius_int_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_L.stl
new file mode 100644
index 00000000..60e0ede2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_R.stl
new file mode 100644
index 00000000..6faf0dd3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.omotransversarius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oculi.stl b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oculi.stl
new file mode 100644
index 00000000..c01e9581
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oculi.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_L.stl
new file mode 100644
index 00000000..e7bfa275
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_R.stl
new file mode 100644
index 00000000..b9813cc6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.orbicularis_oris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pectineus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.pectineus_L.stl
new file mode 100644
index 00000000..6387ab44
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pectineus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pectineus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.pectineus_R.stl
new file mode 100644
index 00000000..dd2414ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pectineus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_1.stl
new file mode 100644
index 00000000..8cab1940
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_2.stl b/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_2.stl
new file mode 100644
index 00000000..d8e4f324
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pectorales_superficiales_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_L.stl
new file mode 100644
index 00000000..74ce74b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_R.stl
new file mode 100644
index 00000000..32d3dc4f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_L.stl
new file mode 100644
index 00000000..7234f573
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_R.stl
new file mode 100644
index 00000000..ba80dc97
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.peroneus_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.popliteus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.popliteus_L.stl
new file mode 100644
index 00000000..cae9b0ed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.popliteus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.popliteus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.popliteus_R.stl
new file mode 100644
index 00000000..97ff61cd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.popliteus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_L.stl
new file mode 100644
index 00000000..b39e49f7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_R.stl
new file mode 100644
index 00000000..50eeeed4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pronator_quadratus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_L.stl
new file mode 100644
index 00000000..405dd608
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_R.stl
new file mode 100644
index 00000000..f2eaf03c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pronator_teres_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_L.stl
new file mode 100644
index 00000000..db9d2736
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_R.stl
new file mode 100644
index 00000000..642d10fb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_L.stl
new file mode 100644
index 00000000..e2dfb0ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_R.stl
new file mode 100644
index 00000000..e97749f5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_L.stl
new file mode 100644
index 00000000..fb698ce5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_R.stl
new file mode 100644
index 00000000..fee2ef18
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_L.stl
new file mode 100644
index 00000000..c0195c0a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_R.stl
new file mode 100644
index 00000000..513e8a84
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.pterygoideus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_L.stl
new file mode 100644
index 00000000..62a22dab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_R.stl
new file mode 100644
index 00000000..8d9c35f7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.quadratus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_L.stl
new file mode 100644
index 00000000..4c7d1313
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_R.stl
new file mode 100644
index 00000000..5fed4fa2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_L.stl
new file mode 100644
index 00000000..c30ec6c8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_R.stl
new file mode 100644
index 00000000..47858d53
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_L.stl
new file mode 100644
index 00000000..ffe591f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_R.stl
new file mode 100644
index 00000000..bcd038ee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_dorsalis_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_L.stl
new file mode 100644
index 00000000..3ffee81f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_R.stl
new file mode 100644
index 00000000..23e2d0cb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_L.stl
new file mode 100644
index 00000000..63bde44b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_R.stl
new file mode 100644
index 00000000..5da69143
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_capitis_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_L.stl
new file mode 100644
index 00000000..11c2bb35
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_R.stl
new file mode 100644
index 00000000..5708f7b8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_L.stl
new file mode 100644
index 00000000..bd59095e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_R.stl
new file mode 100644
index 00000000..c731c145
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rectus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_L.stl
new file mode 100644
index 00000000..18f8ddc2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_R.stl
new file mode 100644
index 00000000..fb69ce32
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.retractor_anguli_oculi_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_L.stl
new file mode 100644
index 00000000..3cb4fcf2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_R.stl
new file mode 100644
index 00000000..9f52194e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L(2).stl
new file mode 100644
index 00000000..f047be39
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L.stl
new file mode 100644
index 00000000..83bd2ea5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R(2).stl
new file mode 100644
index 00000000..dd4d5f8b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R.stl
new file mode 100644
index 00000000..780c64cf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sartorius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_L.stl
new file mode 100644
index 00000000..eca5fb88
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_R.stl
new file mode 100644
index 00000000..a82074cf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.semimembranosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_L.stl
new file mode 100644
index 00000000..eec6670a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_R.stl
new file mode 100644
index 00000000..fad6cb0a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.semitendinosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L1.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L1.stl
new file mode 100644
index 00000000..1e28aec4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L2.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L2.stl
new file mode 100644
index 00000000..abf0f4a3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L3.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L3.stl
new file mode 100644
index 00000000..58906090
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R1.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R1.stl
new file mode 100644
index 00000000..9436068f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R2.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R2.stl
new file mode 100644
index 00000000..f2584164
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R3.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R3.stl
new file mode 100644
index 00000000..7eb216bb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_caudalis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..8d9be4bb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..671eba0b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.splenius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.splenius_L.stl
new file mode 100644
index 00000000..82765d60
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.splenius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.splenius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.splenius_R.stl
new file mode 100644
index 00000000..560e7309
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.splenius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_L.stl
new file mode 100644
index 00000000..90f6879c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_R.stl
new file mode 100644
index 00000000..71b63ef8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sternocephalicus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_L.stl
new file mode 100644
index 00000000..8d1baad6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_R.stl
new file mode 100644
index 00000000..e98605b6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.sternothyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_L.stl
new file mode 100644
index 00000000..31259df4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_R.stl
new file mode 100644
index 00000000..74c57955
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.subscapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.supinator_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.supinator_L.stl
new file mode 100644
index 00000000..471bded8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.supinator_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.supinator_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.supinator_R.stl
new file mode 100644
index 00000000..db5c65bc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.supinator_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_L.stl
new file mode 100644
index 00000000..02673453
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_R.stl
new file mode 100644
index 00000000..fb527718
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.supraspinatus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.temporalis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.temporalis_L.stl
new file mode 100644
index 00000000..fc50b8e4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.temporalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.temporalis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.temporalis_R.stl
new file mode 100644
index 00000000..6daa0b14
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.temporalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_L.stl
new file mode 100644
index 00000000..f8b7aefd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_R.stl
new file mode 100644
index 00000000..19b5ae87
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tensor_f.antebrachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.teres_major_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.teres_major_L.stl
new file mode 100644
index 00000000..6d73dfc4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.teres_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.teres_major_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.teres_major_R.stl
new file mode 100644
index 00000000..5d37cbdf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.teres_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_L.stl
new file mode 100644
index 00000000..cf9bfb6d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_R.stl
new file mode 100644
index 00000000..a50ae6d9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.teres_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_L.stl
new file mode 100644
index 00000000..422e6d60
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_R.stl
new file mode 100644
index 00000000..6a6ed5d3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_caudalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_L.stl
new file mode 100644
index 00000000..b54032bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_R.stl
new file mode 100644
index 00000000..b7561e42
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tibialis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_L.stl
new file mode 100644
index 00000000..6b865470
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_R.stl
new file mode 100644
index 00000000..45fb7ef1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_L.stl
new file mode 100644
index 00000000..e6076d1b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_R.stl
new file mode 100644
index 00000000..eba23786
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.trapezius.stl b/dm_control/suite/dog_assets/extras/MUSCm.trapezius.stl
new file mode 100644
index 00000000..437100f9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_L.stl
new file mode 100644
index 00000000..3db0733d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_R.stl
new file mode 100644
index 00000000..c266d0c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_accessory_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_lateral_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_lateral_L.stl
new file mode 100644
index 00000000..ca080715
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_lateral_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L(2).stl
new file mode 100644
index 00000000..f6525c8f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L.stl
new file mode 100644
index 00000000..749d7942
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R(2).stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R(2).stl
new file mode 100644
index 00000000..1c2cb58c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R.stl
new file mode 100644
index 00000000..22e966dc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_long_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_L.stl
new file mode 100644
index 00000000..c9338247
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_R.stl
new file mode 100644
index 00000000..4b75626b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.triceps_brachii_medial_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.tricepsbrachii_lateral_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.tricepsbrachii_lateral_R.stl
new file mode 100644
index 00000000..8c926a32
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.tricepsbrachii_lateral_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_L.stl
new file mode 100644
index 00000000..c5a17cf8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_R.stl
new file mode 100644
index 00000000..a5f5543e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_intermedius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_L.stl
new file mode 100644
index 00000000..169c44fb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_R.stl
new file mode 100644
index 00000000..cf108c4b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_L.stl
new file mode 100644
index 00000000..64cffe98
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_R.stl
new file mode 100644
index 00000000..36cbec33
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.vastus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_L.stl b/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_L.stl
new file mode 100644
index 00000000..9d585fbb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_R.stl b/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_R.stl
new file mode 100644
index 00000000..f8efdbbf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm.zygomaticus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_L.stl
new file mode 100644
index 00000000..cdc5d3f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_R.stl
new file mode 100644
index 00000000..aef93a45
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_L.stl b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_L.stl
new file mode 100644
index 00000000..54d218ed
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_R.stl b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_R.stl
new file mode 100644
index 00000000..b12fa360
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_biceps_femoris_tendon_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_extensor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/MUSCm_extensor_capri_radialis_L.stl
new file mode 100644
index 00000000..95f76d2d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_extensor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/MUSCm_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..3c20ea11
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_L.stl b/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_L.stl
new file mode 100644
index 00000000..1a9edf78
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_R.stl b/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_R.stl
new file mode 100644
index 00000000..4e7b4a9a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCm_tensor_f.latae_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext.L.stl b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext.L.stl
new file mode 100644
index 00000000..0437296a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext.L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext_R.stl b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext_R.stl
new file mode 100644
index 00000000..3aed61bc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_ext_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int.L.stl b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int.L.stl
new file mode 100644
index 00000000..e21d1a75
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int.L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int_R.stl b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int_R.stl
new file mode 100644
index 00000000..a3e7242a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCmm.intercostals_int_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCmm.thoracis_m.pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/MUSCmm.thoracis_m.pectoralis_profundus.stl
new file mode 100644
index 00000000..9575f802
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCmm.thoracis_m.pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_L.stl b/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_L.stl
new file mode 100644
index 00000000..079ca071
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_R.stl b/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_R.stl
new file mode 100644
index 00000000..5aa69028
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/MUSCvastus_lat_ander_fascia_lata_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_L.stl b/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_L.stl
new file mode 100644
index 00000000..15e340d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_R.stl b/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_R.stl
new file mode 100644
index 00000000..e59ae70d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITE-1=m_obliquus_capitis_caudalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_L.stl b/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_L.stl
new file mode 100644
index 00000000..3fa8b896
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_R.stl b/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_R.stl
new file mode 100644
index 00000000..f68ac986
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITE-2=m_obliquus_capitis_caudalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L(2).stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L(2).stl
new file mode 100644
index 00000000..9ec0843a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L.stl
new file mode 100644
index 00000000..f300e283
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R(2).stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R(2).stl
new file mode 100644
index 00000000..fc548b7d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R.stl
new file mode 100644
index 00000000..4a42ca25
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_obliquus_capitis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_L.stl
new file mode 100644
index 00000000..df96ff65
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_R.stl
new file mode 100644
index 00000000..8f2a0f29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_dorsalis_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_L.stl
new file mode 100644
index 00000000..16412980
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_R.stl
new file mode 100644
index 00000000..50d60f0b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_L.stl
new file mode 100644
index 00000000..7c7215b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_R.stl
new file mode 100644
index 00000000..8379a1ee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-1=m_rectus_capitis_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..ed238a03
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..f439e255
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-3=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..93bed907
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..3d43cb10
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-4=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..9497ab4f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..c0dddcd6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-5=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..11545d5a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..e79c42ae
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-6=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..d5dd45c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..61763ab3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC-7=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_1=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEC_1=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..b11a7324
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_1=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_L.stl b/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_L.stl
new file mode 100644
index 00000000..24566909
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_R.stl b/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_R.stl
new file mode 100644
index 00000000..d4779499
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_1=m_omotransversarius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_2=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEC_2=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..99fa512a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_2=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_L.stl b/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_L.stl
new file mode 100644
index 00000000..aa676ff8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_R.stl b/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_R.stl
new file mode 100644
index 00000000..75e3ac12
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_2=m_rectus_capitis_dorsalis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_L.stl b/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_L.stl
new file mode 100644
index 00000000..ea1722b3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_R.stl b/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_R.stl
new file mode 100644
index 00000000..9952e82f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_2=m_splenius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..3ab57b4b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..6f031e26
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..54d717ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..b25d23cb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_L.stl
new file mode 100644
index 00000000..d6ab92c6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_R.stl
new file mode 100644
index 00000000..5c70ad3c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_3=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..75d29e21
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..b0e3b854
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..fea590b9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..788ae7c6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_L.stl
new file mode 100644
index 00000000..9f623228
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_R.stl
new file mode 100644
index 00000000..2f73eec9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_4=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..75c34612
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..2e055865
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_5=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..3e4aa539
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..22200438
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_6=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..3ad04d64
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..89b71192
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..d49020e1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..17281a68
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..79d0f5e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..dd766c2b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_L.stl
new file mode 100644
index 00000000..1e6bedc2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_R.stl
new file mode 100644
index 00000000..5e074a4f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEC_7=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..d53e258d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..29701589
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..3297480b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..661c1879
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..377e993d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..48c1a5ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..92fa9196
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..bf1d99ee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..3b092f29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..eb104518
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_10=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..bbdb7974
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..953a76e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..208962fc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..590c6688
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..b076bdee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..4a8985e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..e5f74ef8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..647f3dbe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.001.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.001.stl
new file mode 100644
index 00000000..71406aad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..71406aad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.001.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.001.stl
new file mode 100644
index 00000000..ce9eaf73
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..5737c628
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_11=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..776cc5ba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..d36d06af
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..7f6d2917
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..22ddb408
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..9932e32e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..59e79288
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..8dd29e30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..c96d8d5f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..60e2b089
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..391703b8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_12=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..65ceb1f6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..5a2790b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..043baeb7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..c6dc33d3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..6669759c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..cb9c077e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..216db037
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..f7617b25
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..8bdc2722
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..78d79a1c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_13=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..8fb157b2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..8656eb01
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..3d518c68
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..1ac15b79
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..433beb48
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..96ab73ee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..386c640d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..e00d60d8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..875c194a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..8e589fdf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_14=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..947582f6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..04193e1b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..0f44fe53
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..e00ab676
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..813bf842
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..45028a17
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..93da8d65
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..cdb03e30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..96884dfa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..bb25692a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_15=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..6a401621
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..99dd1798
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..9b32f779
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..7ed23c2a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..5cca080e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..87e6000e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..82b1a3e3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..b7145ccf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..566e694e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..990ad921
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_16=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..8d85ac5e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..1efea153
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..407132a4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..73f5a0fa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..c0a24e8b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..10f06afe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..42cbc764
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..509bc9ae
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..b9fc8790
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..91f02c4e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_17=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..209f4057
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..3639b7e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..f06e551c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..63350ca4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..9b0865e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..e33fd418
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..38e58638
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..b172128f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..8647dbee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..a9c9e783
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_18=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..0298efde
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..972e9fe6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..1e932bf8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..3088d24f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..fdf73763
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..99c4dcab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..15202904
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..87f9bf8e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..cbee4298
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..dacf0cc3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_19=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_L.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_L.stl
new file mode 100644
index 00000000..62121fa9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_R.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_R.stl
new file mode 100644
index 00000000..c4ead932
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_abductor_cruris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..07005320
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..e3d6abcf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..c6bdb0fb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..acaa364e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..0ca6c822
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..61fc5bfe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.001.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.001.stl
new file mode 100644
index 00000000..5788ac4b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.stl
new file mode 100644
index 00000000..3a737c49
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.001.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.001.stl
new file mode 100644
index 00000000..c54352df
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.stl
new file mode 100644
index 00000000..5a5891ea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_gluteus_superficialis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L.stl
new file mode 100644
index 00000000..626a6f91
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..ab60982b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R.stl
new file mode 100644
index 00000000..1d25fd8c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..a2a20b8b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..e773e52c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..861b13ea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_1=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..d9359299
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..078e83be
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..2f038b57
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..ac8cd004
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..ce57edd4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..14cfe97c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..e3576de6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..052e4dfa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..84a3b419
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..7d4b3ffd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_20=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..d990801a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..6489f3e7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..f7f4588e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..92c0de17
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..659943c3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..07d16131
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..e9b62441
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..650f88c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..a48728bf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..a3e0e4e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_21=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_L.stl
new file mode 100644
index 00000000..8f1a0fad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_R.stl
new file mode 100644
index 00000000..85f6b1bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_coccygeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..ba2ac6a8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..f23e7e35
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..640633f4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..0d1afd8c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..c2afc556
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..6af9ad08
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L1.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L1.stl
new file mode 100644
index 00000000..b9bb412f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..3f2c041c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R1.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R1.stl
new file mode 100644
index 00000000..a75df77a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..9d587bd2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..cba390da
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..ec5cfb3b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_2=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_L.stl
new file mode 100644
index 00000000..c4608657
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_R.stl
new file mode 100644
index 00000000..46ab4494
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_coccygeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..6fd549af
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..c25bf2f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..cc68f8fa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..1fa08791
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..fc167f43
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..e388a433
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..d6c905c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..6ec446cc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..f914cf81
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..99d87c85
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_3=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_L.stl
new file mode 100644
index 00000000..4ed80767
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_R.stl
new file mode 100644
index 00000000..abe72001
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_coccygeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..9c618932
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..cb4fed4f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..20053787
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..f54a9f98
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..402ccfe7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..4b45c57b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_L.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_L.stl
new file mode 100644
index 00000000..cca6291f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_R.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_R.stl
new file mode 100644
index 00000000..79eec10c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_ani_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..bcb0252a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..2fcae223
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..a29905f6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..5febbf99
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_4=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..fb83ddd7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..a7fdc546
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..22037f59
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..f9555323
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..a766949e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..b0c7e115
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..78ec17a7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..71aeff1c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..e507bd9e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..70d2913a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_5=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..32bbbb4c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..186569bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..94fd90f6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..91b18905
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..befc504f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..b20ccfda
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..ac38c766
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..705eb2be
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..bd53646f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..0f69eb68
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_6=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..d6eb5e6c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..9367191b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..d08700b5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..6fb7e77b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..91d073e8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..aa6a62fb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..1e4db32a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..dd72301e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..42319e49
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..746628de
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_7=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..14dedb3a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..73d3bfd0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..8ad9a4c9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..bf90b386
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..54674e95
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..8dd4f203
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..c3c67741
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..2c0fa7e1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..6c01fa37
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..83cbd657
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_8=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..b5f11793
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..2f7bfa9b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..99e8a031
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..55833174
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..80e8e9b1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..6b818acf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..75dc1bb6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..70321491
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..d3e081c4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..67146d47
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECa_9=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_L=m_gastrocnemius_L.stl b/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_L=m_gastrocnemius_L.stl
new file mode 100644
index 00000000..92c674ea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_L=m_gastrocnemius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_R=m_gastrocnemius_R.stl b/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_R=m_gastrocnemius_R.stl
new file mode 100644
index 00000000..d1166236
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECalcaneal_tuber_R=m_gastrocnemius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECarpal_II_L=m_extensor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/SITECarpal_II_L=m_extensor_capri_radialis_L.stl
new file mode 100644
index 00000000..afc0c6ff
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECarpal_II_L=m_extensor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECarpal_II_R=m_extensor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/SITECarpal_II_R=m_extensor_capri_radialis_R.stl
new file mode 100644
index 00000000..d5b59e19
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECarpal_II_R=m_extensor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECarpal_accessory_L=m.flexor_carpi_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/SITECarpal_accessory_L=m.flexor_carpi_ulnaris_L.stl
new file mode 100644
index 00000000..d102d3b1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECarpal_accessory_L=m.flexor_carpi_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITECarpal_accessory_R=m.flexor_carpi_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/SITECarpal_accessory_R=m.flexor_carpi_ulnaris_R.stl
new file mode 100644
index 00000000..3dff8ba2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITECarpal_accessory_R=m.flexor_carpi_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L1.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L1.stl
new file mode 100644
index 00000000..420d73c0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L2.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L2.stl
new file mode 100644
index 00000000..e213de8e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L3.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L3.stl
new file mode 100644
index 00000000..ffeb6766
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R1.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R1.stl
new file mode 100644
index 00000000..8c3b69d6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R2.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R2.stl
new file mode 100644
index 00000000..f3ddccf0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R3.stl b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R3.stl
new file mode 100644
index 00000000..627d0348
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFascia_thoracolumbar=m_serratus_dorsalis_caudalis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L1=m_gastrocnemius_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L1=m_gastrocnemius_L.stl
new file mode 100644
index 00000000..bf7d9003
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L1=m_gastrocnemius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L2=m_gastrocnemius_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L2=m_gastrocnemius_L.stl
new file mode 100644
index 00000000..4e9ec794
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L2=m_gastrocnemius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_abductor_cruris_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_abductor_cruris_L.stl
new file mode 100644
index 00000000..2eb7e649
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_abductor_cruris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_longus_L.stl
new file mode 100644
index 00000000..443d92d5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_magnus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_magnus_L.stl
new file mode 100644
index 00000000..600826c1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_adductor_magnus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_capsularis_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_capsularis_L.stl
new file mode 100644
index 00000000..f5a6c043
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_capsularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..df400e86
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gemellus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gemellus_L.stl
new file mode 100644
index 00000000..ec17b6be
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gemellus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_medius_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_medius_L.stl
new file mode 100644
index 00000000..e34770b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_medius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_profundus_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_profundus_minor_L.stl
new file mode 100644
index 00000000..db9df5d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_profundus_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_superficialis_major_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_superficialis_major_L.stl
new file mode 100644
index 00000000..2acdf661
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_gluteus_superficialis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_iliacus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_iliacus_L.stl
new file mode 100644
index 00000000..a5e3c03f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_iliacus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_ext_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_ext_L.stl
new file mode 100644
index 00000000..3794601c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_ext_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_int_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_int_L.stl
new file mode 100644
index 00000000..93e14932
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_obturatorius_int_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_pectineus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_pectineus_L.stl
new file mode 100644
index 00000000..9c3ab84a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_pectineus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_popliteus_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_popliteus_L.stl
new file mode 100644
index 00000000..4314df9c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_popliteus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_psoas_major_L.stl
new file mode 100644
index 00000000..b678dfa4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_quadratus_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_quadratus_femoris_L.stl
new file mode 100644
index 00000000..95568edf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_quadratus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_rectus_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_rectus_femoris_L.stl
new file mode 100644
index 00000000..1e04d021
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_rectus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_intermedius_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_intermedius_L.stl
new file mode 100644
index 00000000..984a9778
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_intermedius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_lat_L.stl
new file mode 100644
index 00000000..2be7b894
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_med_L.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_med_L.stl
new file mode 100644
index 00000000..ec0aa4a5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_L=m_vastus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R1=m_gastrocnemius_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R1=m_gastrocnemius_R.stl
new file mode 100644
index 00000000..e6b93fcb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R1=m_gastrocnemius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R2=m_gastrocnemius_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R2=m_gastrocnemius_R.stl
new file mode 100644
index 00000000..6ffe1ad7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R2=m_gastrocnemius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_abductor_cruris_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_abductor_cruris_R.stl
new file mode 100644
index 00000000..b8841413
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_abductor_cruris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_longus_R.stl
new file mode 100644
index 00000000..ce75c494
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_magnus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_magnus_R.stl
new file mode 100644
index 00000000..f01561ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_adductor_magnus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_capsularis_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_capsularis_R.stl
new file mode 100644
index 00000000..fcaace4b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_capsularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..5a8368bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gemellus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gemellus_R.stl
new file mode 100644
index 00000000..4370ba0a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gemellus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_medius_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_medius_R.stl
new file mode 100644
index 00000000..3cde1c88
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_medius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_profundus_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_profundus_minor_R.stl
new file mode 100644
index 00000000..e7ddaddd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_profundus_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_superficialis_major_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_superficialis_major_R.stl
new file mode 100644
index 00000000..a49ed1c6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_gluteus_superficialis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_iliacus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_iliacus_R.stl
new file mode 100644
index 00000000..9c376ba4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_iliacus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_ext_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_ext_R.stl
new file mode 100644
index 00000000..1d0fba22
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_ext_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_int_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_int_R.stl
new file mode 100644
index 00000000..f4f76e01
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_obturatorius_int_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_pectineus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_pectineus_R.stl
new file mode 100644
index 00000000..ccd363b7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_pectineus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_popliteus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_popliteus_R.stl
new file mode 100644
index 00000000..d25279ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_popliteus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_psoas_major_R.stl
new file mode 100644
index 00000000..59be51f0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_quadratus_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_quadratus_femoris_R.stl
new file mode 100644
index 00000000..4b6755c2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_quadratus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_rectus_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_rectus_femoris_R.stl
new file mode 100644
index 00000000..e31cb7af
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_rectus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_intermedius_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_intermedius_R.stl
new file mode 100644
index 00000000..1161cba4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_intermedius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_lat_R.stl
new file mode 100644
index 00000000..7c6c6994
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_med_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_med_R.stl
new file mode 100644
index 00000000..abbb59ba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_R=m_vastus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_L_2=m_semimembranosus_L_.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_L_2=m_semimembranosus_L_.stl
new file mode 100644
index 00000000..af7aa079
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_L_2=m_semimembranosus_L_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_R_2=m_semimembranosus_R.stl b/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_R_2=m_semimembranosus_R.stl
new file mode 100644
index 00000000..46ccd243
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFemoris_fabellae_R_2=m_semimembranosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_L=m_extensor_digitorum_lateralis_L.stl b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_extensor_digitorum_lateralis_L.stl
new file mode 100644
index 00000000..96e757a7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_extensor_digitorum_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..61699beb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_brevis_L.stl
new file mode 100644
index 00000000..2eab1299
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_longus_L.stl
new file mode 100644
index 00000000..f19faee0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_peroneus_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_L=m_tibialis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_tibialis_cranialis_L.stl
new file mode 100644
index 00000000..02c9dd4e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_L=m_tibialis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_R=m_extensor_digitorum_lateralis_R.stl b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_extensor_digitorum_lateralis_R.stl
new file mode 100644
index 00000000..9002c2f5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_extensor_digitorum_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..87f98654
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_brevis_R.stl
new file mode 100644
index 00000000..caa3e88e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_longus_R.stl
new file mode 100644
index 00000000..63cedf28
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_peroneus_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEFibula_R=m_tibialis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_tibialis_cranialis_R.stl
new file mode 100644
index 00000000..2db84226
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEFibula_R=m_tibialis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_L.stl
new file mode 100644
index 00000000..71ffc81d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_R.stl
new file mode 100644
index 00000000..632904e0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_L.stl
new file mode 100644
index 00000000..19fe8033
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_R.stl
new file mode 100644
index 00000000..badb9350
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-1=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_L.stl
new file mode 100644
index 00000000..d6d03601
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_R.stl
new file mode 100644
index 00000000..ca4fd1b1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_L.stl
new file mode 100644
index 00000000..487cdd0e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_R.stl
new file mode 100644
index 00000000..5b3e765c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-2=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_L.stl
new file mode 100644
index 00000000..27d73061
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_R.stl
new file mode 100644
index 00000000..8bbeacbe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_L.stl
new file mode 100644
index 00000000..02d449d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_R.stl
new file mode 100644
index 00000000..27c96d69
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-3=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_L.stl
new file mode 100644
index 00000000..c37f7954
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_R.stl
new file mode 100644
index 00000000..517ff2dd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_L.stl
new file mode 100644
index 00000000..67250f2f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_R.stl
new file mode 100644
index 00000000..91f93375
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-4=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_L.stl
new file mode 100644
index 00000000..c41ab066
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_R.stl
new file mode 100644
index 00000000..e252d0ca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-5=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_L.stl
new file mode 100644
index 00000000..53ad4e45
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_R.stl
new file mode 100644
index 00000000..30e37381
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-6=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_L.stl
new file mode 100644
index 00000000..1dabafe2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_R.stl
new file mode 100644
index 00000000..85fdcea7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL-7=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..3b3bd444
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..12413609
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..77b6cbca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..aa2186ff
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..e71e286e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..39ece820
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..85cb5ed5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_1=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..5c9714b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..c96cd273
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..89b75044
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..590dfe9b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..89592944
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..4d5f7061
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..55c65c87
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..247038f3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_2=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..c9b649f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..37ebd409
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..7b57247a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..7930dbb9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..205b7b9d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..6f35f112
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..12ade64e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..b994a33e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_3=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..ccb38c53
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..8b19c29b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..c17c0734
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..8cec455f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..51d1870b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..0d08068e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..e2916a83
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..1903f451
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_4=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..46b1866a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..8c135502
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..c7311b4d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..8cf74476
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..f683f899
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..823f4676
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..6c9bd0c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..950bd26c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_5=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..ffcd477b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..dd00e511
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..4b236ec3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..96418141
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..f483b959
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..a8aa3527
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..0e04d93c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..10825c88
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_6=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..e74f26ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..66fd93a3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_L.stl
new file mode 100644
index 00000000..5349d6df
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_R.stl
new file mode 100644
index 00000000..fcfc3ba4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..f8997b70
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..ad263a24
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..2c3ebb08
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..aab525ac
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..742132a4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..46feef3b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..44eced51
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..91b57646
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEL_7=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-10=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-10=m_trapezius.stl
new file mode 100644
index 00000000..9d39d3aa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-10=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-11=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-11=m_trapezius.stl
new file mode 100644
index 00000000..b4702550
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-11=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-12=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-12=m_trapezius.stl
new file mode 100644
index 00000000..525127b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-12=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-13=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-13=m_trapezius.stl
new file mode 100644
index 00000000..9c5a95b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-13=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-1=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-1=m_trapezius.stl
new file mode 100644
index 00000000..a871cc77
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-1=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-2=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-2=m_trapezius.stl
new file mode 100644
index 00000000..bac3952a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-2=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-3=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-3=m_trapezius.stl
new file mode 100644
index 00000000..716b1dca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-3=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-4=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-4=m_trapezius.stl
new file mode 100644
index 00000000..78fed0b1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-4=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-5=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-5=m_trapezius.stl
new file mode 100644
index 00000000..43b21152
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-5=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-6=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-6=m_trapezius.stl
new file mode 100644
index 00000000..ec503444
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-6=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-7=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-7=m_trapezius.stl
new file mode 100644
index 00000000..ec503444
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-7=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-8=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-8=m_trapezius.stl
new file mode 100644
index 00000000..e4e29ba3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-8=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament-9=m_trapezius.stl b/dm_control/suite/dog_assets/extras/SITELigament-9=m_trapezius.stl
new file mode 100644
index 00000000..c1e1c496
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament-9=m_trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_L.stl
new file mode 100644
index 00000000..41166a8a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_R.stl
new file mode 100644
index 00000000..105a0f5f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..3309284b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..00d633cd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELigament=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_L.stl b/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_L.stl
new file mode 100644
index 00000000..8e16748d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_R.stl b/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_R.stl
new file mode 100644
index 00000000..7e22ea7b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELingual_bone_1=m_sternohyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_L.stl b/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_L.stl
new file mode 100644
index 00000000..bef3d99a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_R.stl b/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_R.stl
new file mode 100644
index 00000000..9bb14bae
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITELingual_bone_3=m_sternothyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_L.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_L.stl
new file mode 100644
index 00000000..e1bea8bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_R.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_R.stl
new file mode 100644
index 00000000..4409ce59
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_digastricus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_L.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_L.stl
new file mode 100644
index 00000000..a4633e28
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_R.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_R.stl
new file mode 100644
index 00000000..432ba6ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_masseter_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_L.stl
new file mode 100644
index 00000000..81839473
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_R.stl
new file mode 100644
index 00000000..fe3b955b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_L.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_L.stl
new file mode 100644
index 00000000..91a701f3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_R.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_R.stl
new file mode 100644
index 00000000..8fa6887b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_pterygoideus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_L.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_L.stl
new file mode 100644
index 00000000..c5a90dee
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_R.stl b/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_R.stl
new file mode 100644
index 00000000..ecfbefc3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMandible=m_temporalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMetatarsi_L_4=m_peroneus_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITEMetatarsi_L_4=m_peroneus_brevis_L.stl
new file mode 100644
index 00000000..104c24cc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMetatarsi_L_4=m_peroneus_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMetatarsi_R_4=m_peroneus_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITEMetatarsi_R_4=m_peroneus_brevis_R.stl
new file mode 100644
index 00000000..a95d9c7e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMetatarsi_R_4=m_peroneus_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_L=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_L=m_latissimus_dors.stl
new file mode 100644
index 00000000..a56a520e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_L=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_R=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_R=m_latissimus_dors.stl
new file mode 100644
index 00000000..d0ced8d1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEMm.dorsi_Fascia_thoracolumbar_R=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_L=m_flexor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_L=m_flexor_capri_radialis_L.stl
new file mode 100644
index 00000000..02fdc69f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_L=m_flexor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_R=m_flexor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_R=m_flexor_capri_radialis_R.stl
new file mode 100644
index 00000000..7abadd84
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_III_R=m_flexor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_L=m_flexor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_L=m_flexor_capri_radialis_L.stl
new file mode 100644
index 00000000..aea532cb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_L=m_flexor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_R=m_flexor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_R=m_flexor_capri_radialis_R.stl
new file mode 100644
index 00000000..86df04d3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_II_R=m_flexor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_L=m_abductor_pollicis_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_L=m_abductor_pollicis_longus_L.stl
new file mode 100644
index 00000000..1f696aa3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_L=m_abductor_pollicis_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_R=m_abductor_pollicis_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_R=m_abductor_pollicis_longus_R.stl
new file mode 100644
index 00000000..6c845600
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_I_R=m_abductor_pollicis_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_L=m.extensor_capri_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_L=m.extensor_capri_ulnaris_L.stl
new file mode 100644
index 00000000..39044f0f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_L=m.extensor_capri_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_R=m.extensor_capri_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_R=m.extensor_capri_ulnaris_R.stl
new file mode 100644
index 00000000..49ff68b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEOs_metacarpale_V_R=m.extensor_capri_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_biceps_femoris_L.stl
new file mode 100644
index 00000000..73eac3a5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_rectus_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_rectus_femoris_L.stl
new file mode 100644
index 00000000..9d2991f7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_rectus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_sartorius_L(2).stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_sartorius_L(2).stl
new file mode 100644
index 00000000..f1b6d718
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_sartorius_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_intermedius_L.stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_intermedius_L.stl
new file mode 100644
index 00000000..d4e7f150
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_intermedius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_lat_L.stl
new file mode 100644
index 00000000..f1a02fd8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_med_L.stl b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_med_L.stl
new file mode 100644
index 00000000..045dcead
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_L=m_vastus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_biceps_femoris_R.stl
new file mode 100644
index 00000000..421f58bc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_rectus_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_rectus_femoris_R.stl
new file mode 100644
index 00000000..41ed618f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_rectus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_sartorius_R(2).stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_sartorius_R(2).stl
new file mode 100644
index 00000000..6b4c30d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_sartorius_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_intermedius_R.stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_intermedius_R.stl
new file mode 100644
index 00000000..4937614f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_intermedius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_lat_R.stl
new file mode 100644
index 00000000..9da98ebb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_med_R.stl b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_med_R.stl
new file mode 100644
index 00000000..0da9bfb4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPatella_R=m_vastus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..8cb4b0a5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..4fa2569c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_L.stl
new file mode 100644
index 00000000..5b10d888
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_R.stl
new file mode 100644
index 00000000..8852bff8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_abductor_cruris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_L.stl
new file mode 100644
index 00000000..691848c9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_R.stl
new file mode 100644
index 00000000..423a392d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_L.stl
new file mode 100644
index 00000000..01bac46d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_R.stl
new file mode 100644
index 00000000..da109d99
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_adductor_magnus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_L.stl
new file mode 100644
index 00000000..77372466
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_R.stl
new file mode 100644
index 00000000..4a79c5d4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_L.stl
new file mode 100644
index 00000000..3cfa9575
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_R.stl
new file mode 100644
index 00000000..0c12433e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_capsularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_L.stl
new file mode 100644
index 00000000..7ca02db9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_R.stl
new file mode 100644
index 00000000..269c53d9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_coccygeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_L.stl
new file mode 100644
index 00000000..2fae21a8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_R.stl
new file mode 100644
index 00000000..2632a43c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gemellus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_L.stl
new file mode 100644
index 00000000..7d5df74e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_R.stl
new file mode 100644
index 00000000..7e35d9d2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_medius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_L.stl
new file mode 100644
index 00000000..637f81f7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_R.stl
new file mode 100644
index 00000000..c288c389
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gluteus_profundus_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_L.stl
new file mode 100644
index 00000000..c2529821
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_R.stl
new file mode 100644
index 00000000..f8ce1a75
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_gracilis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_L.stl
new file mode 100644
index 00000000..32cf93ea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_R.stl
new file mode 100644
index 00000000..9809860a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_iliacus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_L.stl
new file mode 100644
index 00000000..3129c76b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_R.stl
new file mode 100644
index 00000000..e6e22976
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_levator_ani_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..fe03d12f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..db5b2f51
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..e3594b8a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..deb6aafd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_L.stl
new file mode 100644
index 00000000..2d0c1bf4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_R.stl
new file mode 100644
index 00000000..d7907d30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obliquus_int_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_L.stl
new file mode 100644
index 00000000..37617e96
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_R.stl
new file mode 100644
index 00000000..ef44785f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_ext_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_L.stl
new file mode 100644
index 00000000..563c0e48
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_R.stl
new file mode 100644
index 00000000..2a255c3a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_obturatorius_int_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_L.stl
new file mode 100644
index 00000000..bdc3b042
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_R.stl
new file mode 100644
index 00000000..5274d6fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_pectineus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_L.stl
new file mode 100644
index 00000000..38da1432
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_R.stl
new file mode 100644
index 00000000..361c1349
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_L.stl
new file mode 100644
index 00000000..93750d6f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_R.stl
new file mode 100644
index 00000000..9e8f6644
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_quadratus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..95e2f361
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..049381c5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_L.stl
new file mode 100644
index 00000000..e31a54ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_R.stl
new file mode 100644
index 00000000..93a3f521
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_rectus_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L(2).stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L(2).stl
new file mode 100644
index 00000000..d574e9d5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L.stl
new file mode 100644
index 00000000..15111e74
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R(2).stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R(2).stl
new file mode 100644
index 00000000..28c6369a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R.stl
new file mode 100644
index 00000000..3a0113a0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_sartorius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_L_.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_L_.stl
new file mode 100644
index 00000000..5f32fc33
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_L_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_R.stl
new file mode 100644
index 00000000..5cbd04a5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semimembranosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_L.stl
new file mode 100644
index 00000000..b60194b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_R.stl
new file mode 100644
index 00000000..8c428891
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_semitendinosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_L.stl
new file mode 100644
index 00000000..9525f7f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_R.stl
new file mode 100644
index 00000000..371b2bc3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tensor_f.latae_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..6ed2474b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..25b3e57b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPelvis=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..db59e8e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..8794aad6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_1=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..53c70f02
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..79921a08
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_2=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..6f21d68c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..e66bb8dd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_3=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_extensor_digitorum_longus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_extensor_digitorum_longus_L.stl
new file mode 100644
index 00000000..e6ff865a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_extensor_digitorum_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..f35c4d73
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_L_4=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..14ffddc0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..568dfa93
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_1=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..1e5b0cc3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..276e1a3e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_2=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..c847b6c5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..abba21bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_3=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_extensor_digitorum_longus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_extensor_digitorum_longus_R.stl
new file mode 100644
index 00000000..00353c0c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_extensor_digitorum_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..fd9d8c6a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_3_R_4=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_L_4=m_extensor_digitorum_lateralis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_L_4=m_extensor_digitorum_lateralis_L.stl
new file mode 100644
index 00000000..dfa15b9a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_L_4=m_extensor_digitorum_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanges_B_R_4=m_extensor_digitorum_lateralis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_R_4=m_extensor_digitorum_lateralis_R.stl
new file mode 100644
index 00000000..3e33d381
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanges_B_R_4=m_extensor_digitorum_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..e22279fe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..e22279fe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..818c11a1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_superficialis_.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_superficialis_.stl
new file mode 100644
index 00000000..a4691eb6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_L=m_flexor_digitorum_superficialis_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..98c354a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..98c354a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..1ffd2895
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_superficialis_.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_superficialis_.stl
new file mode 100644
index 00000000..fc9a93a4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_III_R=m_flexor_digitorum_superficialis_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..72f1f0dc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..d7b41ca4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_superficialis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_superficialis_L.stl
new file mode 100644
index 00000000..1d4e5da4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_L=m_flexor_digitorum_superficialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..f37d6439
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..8e7bea16
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_superficialis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_superficialis_R.stl
new file mode 100644
index 00000000..d9ed4904
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_II_R=m_flexor_digitorum_superficialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..a211c606
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..a211c606
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..e1b8a4dc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_superficialis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_superficialis_L.stl
new file mode 100644
index 00000000..39ea4f73
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_L=m_flexor_digitorum_superficialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..a65d48d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..a65d48d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_flexor_digitorum_superficialis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_flexor_digitorum_superficialis_R.stl
new file mode 100644
index 00000000..517841a5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_IV_R=m_flexor_digitorum_superficialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..12f09bcc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..2443ced3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_I_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..382b108f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..382b108f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..fd3266fe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_superficialis_L.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_superficialis_L.stl
new file mode 100644
index 00000000..ed694472
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_L=m_flexor_digitorum_superficialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..35bedadc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..35bedadc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.001.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.001.stl
new file mode 100644
index 00000000..9948675f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..8aaa8341
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_superficialis_R.stl b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_superficialis_R.stl
new file mode 100644
index 00000000..fe5cfd9e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEPhalanx_distalis_digiti_V_R=m_flexor_digitorum_superficialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m.pronator_teres_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m.pronator_teres_L.stl
new file mode 100644
index 00000000..0827dd4e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m.pronator_teres_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m_abductor_pollicis_longus_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m_abductor_pollicis_longus_L.stl
new file mode 100644
index 00000000..d99d53f3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m_abductor_pollicis_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m_biceps_brachii_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m_biceps_brachii_L.stl
new file mode 100644
index 00000000..5f8c95fc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m_biceps_brachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m_brachialis_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m_brachialis_L.stl
new file mode 100644
index 00000000..49591c08
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m_brachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..064ed572
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_L=m_supinator_L.stl b/dm_control/suite/dog_assets/extras/SITERadius_L=m_supinator_L.stl
new file mode 100644
index 00000000..98bef5ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_L=m_supinator_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m.pronator_teres_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m.pronator_teres_R.stl
new file mode 100644
index 00000000..dc91e2da
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m.pronator_teres_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m_abductor_pollicis_longus_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m_abductor_pollicis_longus_R.stl
new file mode 100644
index 00000000..9dbe8bbc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m_abductor_pollicis_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m_biceps_brachii_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m_biceps_brachii_R.stl
new file mode 100644
index 00000000..2d111201
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m_biceps_brachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m_brachialis_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m_brachialis_R.stl
new file mode 100644
index 00000000..68a8c70d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m_brachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m_extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m_extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..859e9d30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m_extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERadius_R=m_supinator_R.stl b/dm_control/suite/dog_assets/extras/SITERadius_R=m_supinator_R.stl
new file mode 100644
index 00000000..f11c7e5f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERadius_R=m_supinator_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_10=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_10=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..4bce42b2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_10=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..220ec001
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..205ef07f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_10=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..fb5f07fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_10=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..0d17dc50
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_10=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..fb618c3a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_latissimus_dors.stl
new file mode 100644
index 00000000..7219b452
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..ee9516fc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..80a11224
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..34c7cd89
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_11=m_serratus_dorsalis_caudalis_L1.stl b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_serratus_dorsalis_caudalis_L1.stl
new file mode 100644
index 00000000..0e6d3f6d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_11=m_serratus_dorsalis_caudalis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..4c8c3f09
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_latissimus_dors.stl
new file mode 100644
index 00000000..a7074d81
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..5ee16c06
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..e30bdf41
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_obliquus_int_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_obliquus_int_abdominis_L.stl
new file mode 100644
index 00000000..addff7f7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_obliquus_int_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..35bb5c93
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_serratus_dorsalis_caudalis_L2.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_serratus_dorsalis_caudalis_L2.stl
new file mode 100644
index 00000000..b86dd624
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_serratus_dorsalis_caudalis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_12=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..42db6385
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_12=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..63809eca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..b03ab757
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..6fb6b955
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_obliquus_int_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_obliquus_int_abdominis_L.stl
new file mode 100644
index 00000000..8e026511
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_obliquus_int_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..ab8d74c9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_serratus_dorsalis_caudalis_L3.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_serratus_dorsalis_caudalis_L3.stl
new file mode 100644
index 00000000..3dc88ab9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_serratus_dorsalis_caudalis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_13=m_tranversus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_tranversus_abdominis_L.stl
new file mode 100644
index 00000000..3e936e61
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_13=m_tranversus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_dorsi_scaleni_dors_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_dorsi_scaleni_dors_L.stl
new file mode 100644
index 00000000..08d8d6f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_dorsi_scaleni_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..915a97e3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_1=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_1=m_rectus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_1=m_rectus_thoracis_L.stl
new file mode 100644
index 00000000..f1c94bad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_1=m_rectus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_1=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_1=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..ed905ed4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_1=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_dorsi_scaleni_dors_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_dorsi_scaleni_dors_L.stl
new file mode 100644
index 00000000..297be63b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_dorsi_scaleni_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..01397832
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..da4e3537
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..049b3d81
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=m_rectus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_rectus_thoracis_L.stl
new file mode 100644
index 00000000..71d5896e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_rectus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..47a8ad3d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_2=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_2=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..012b623c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_2=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_dorsi_scaleni_dors_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_dorsi_scaleni_dors_L.stl
new file mode 100644
index 00000000..1aed4f0a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_dorsi_scaleni_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..3d68743b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..66136963
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..dd99d622
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=m_rectus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_rectus_thoracis_L.stl
new file mode 100644
index 00000000..a0c7f08a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_rectus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..6d6e8d44
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..a8c07dfa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_3=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_3=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..54bf2993
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_3=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_dorsi_scaleni_dors_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_dorsi_scaleni_dors_L.stl
new file mode 100644
index 00000000..76ee0d6a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_dorsi_scaleni_dors_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..6080731f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..a3c7f2c5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..5937405f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..6688befe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..c2e30f6b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_4=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_4=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..e41c23c6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_4=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..dc39c33b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..5a940b7d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..07518175
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..cad9c0ac
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..85b4c72f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..bdf8983b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_5=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_5=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..cca495e3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_5=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..aabb0f2f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..f87a6d29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..acbf9474
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..81a126f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..3167337f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..d67f98de
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_6=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_6=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..338e80eb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_6=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..089f2668
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..9107418a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..a7d1461d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..99b68cc9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..3361b3de
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..d4cff68d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_7=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_L_7=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..9c47f64d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_7=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_8=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_8=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..09bbda62
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_8=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..80162a18
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..f5a51c93
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_8=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..750c01a9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_8=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..28440f31
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_8=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=Mm_abdominis_m.obliquus_ext._abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=Mm_abdominis_m.obliquus_ext._abdominis_L.stl
new file mode 100644
index 00000000..85a17fb5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=Mm_abdominis_m.obliquus_ext._abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..dc03718d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..fa75fde6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=m_rectus_abdominis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_rectus_abdominis_L.stl
new file mode 100644
index 00000000..96f750f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_rectus_abdominis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=m_serratus_dorsalis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_serratus_dorsalis_cranialis_L.stl
new file mode 100644
index 00000000..ef6dfd72
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_serratus_dorsalis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_L_9=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..84148109
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_L_9=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_10=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_10=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..37a9ee0b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_10=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..c1e58eef
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..af9bf21b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_10=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..0846d314
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_10=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..08d9102a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_10=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..f0ef7ad0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_latissimus_dors.stl
new file mode 100644
index 00000000..49de6439
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..7e792d35
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..b5e3826c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..6da4c40a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_11=m_serratus_dorsalis_caudalis_R1.stl b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_serratus_dorsalis_caudalis_R1.stl
new file mode 100644
index 00000000..ef76fb2d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_11=m_serratus_dorsalis_caudalis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..24b30a03
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_latissimus_dors.stl
new file mode 100644
index 00000000..bc844188
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..b4a59bb7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..04dae27f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_obliquus_int_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_obliquus_int_abdominis_R.stl
new file mode 100644
index 00000000..b1302495
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_obliquus_int_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..ce35e836
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_serratus_dorsalis_caudalis_R2.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_serratus_dorsalis_caudalis_R2.stl
new file mode 100644
index 00000000..49936179
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_serratus_dorsalis_caudalis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_12=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..2062bed8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_12=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..1e53cc0d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..73007b30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..d5b5eb18
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_obliquus_int_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_obliquus_int_abdominis_R.stl
new file mode 100644
index 00000000..5aa64063
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_obliquus_int_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..db0e8724
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_serratus_dorsalis_caudalis_R3.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_serratus_dorsalis_caudalis_R3.stl
new file mode 100644
index 00000000..42f9d29d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_serratus_dorsalis_caudalis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_13=m_tranversus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_tranversus_abdominis_R.stl
new file mode 100644
index 00000000..c3ce19d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_13=m_tranversus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_dorsi_scaleni_dors_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_dorsi_scaleni_dors_R.stl
new file mode 100644
index 00000000..870c108f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_dorsi_scaleni_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..80584de7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_1=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_1=m_rectus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_1=m_rectus_thoracis_R.stl
new file mode 100644
index 00000000..0da62684
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_1=m_rectus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_1=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_1=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..5ff19cf5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_1=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_dorsi_scaleni_dors_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_dorsi_scaleni_dors_R.stl
new file mode 100644
index 00000000..a475b21a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_dorsi_scaleni_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..24c81a6f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..189ebecf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..9eeee676
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=m_rectus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_rectus_thoracis_R.stl
new file mode 100644
index 00000000..8d1a3e52
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_rectus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..f556b4d8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_2=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_2=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..72bcf02f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_2=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_dorsi_scaleni_dors_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_dorsi_scaleni_dors_R.stl
new file mode 100644
index 00000000..05a4e0e1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_dorsi_scaleni_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..1fb8d314
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..6f86155a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..9202785e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=m_rectus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_rectus_thoracis_R.stl
new file mode 100644
index 00000000..ed44acaf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_rectus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..5266c77a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..4b72db03
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_3=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_3=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..8d1d5077
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_3=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_dorsi_scaleni_dors_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_dorsi_scaleni_dors_R.stl
new file mode 100644
index 00000000..e2f9d319
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_dorsi_scaleni_dors_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..a9448d50
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..087ab019
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..20d4e29c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..b9517f88
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..385f47aa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_4=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_4=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..a9e4e1ec
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_4=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..73de5b8c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..a085ef36
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..8b036ac3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..fc368424
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..b9df67c6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..37d1d0cf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_5=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_5=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..062513e3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_5=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..34959ca2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..66775610
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..17f2fe8f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..29f5a7c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..0204b3bd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..dbdc6c09
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.001.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.001.stl
new file mode 100644
index 00000000..0be26330
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..e7064a7f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_6=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..b99cb589
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..7b8d1e37
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..40ad786e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..23a00530
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..ff05a7d0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_7=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..af5e9e5f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_7=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_8=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_8=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..556acede
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_8=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..7e553563
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..56ad3472
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_8=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..19ed6afe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_8=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..f746a051
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_8=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=Mm_abdominis_m.obliquus_ext._abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=Mm_abdominis_m.obliquus_ext._abdominis_R.stl
new file mode 100644
index 00000000..2bcf611a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=Mm_abdominis_m.obliquus_ext._abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..707939c0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..e4377c05
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=m_rectus_abdominis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_rectus_abdominis_R.stl
new file mode 100644
index 00000000..7f26ecb8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_rectus_abdominis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=m_serratus_dorsalis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_serratus_dorsalis_cranialis_R.stl
new file mode 100644
index 00000000..b1af05ad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_serratus_dorsalis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITERib_R_9=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..ea1eac5e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITERib_R_9=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L.stl
new file mode 100644
index 00000000..5ae06ee9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L2.stl
new file mode 100644
index 00000000..91c5e329
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L3.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L3.stl
new file mode 100644
index 00000000..f32b2508
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_L3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R.stl
new file mode 100644
index 00000000..503759b2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R2.stl
new file mode 100644
index 00000000..39676dc7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R3.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R3.stl
new file mode 100644
index 00000000..dadd954d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_depressor_caudae_brevis_R3.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_L.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_L.stl
new file mode 100644
index 00000000..015b3f84
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_R.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_R.stl
new file mode 100644
index 00000000..eddd59b5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_gluteus_superficialis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L1.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L1.stl
new file mode 100644
index 00000000..7fab3e6d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L2.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L2.stl
new file mode 100644
index 00000000..233182f0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R1.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R1.stl
new file mode 100644
index 00000000..8507af1a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R2.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R2.stl
new file mode 100644
index 00000000..0af60c70
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_brevis_R2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_L.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_L.stl
new file mode 100644
index 00000000..55e6125f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_R.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_R.stl
new file mode 100644
index 00000000..dc5b1c03
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_levator_caudae_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..fed30186
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..bdc80b65
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESacrum=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L1=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L1=m_rhomboideus_L.stl
new file mode 100644
index 00000000..63a73052
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L1=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L2=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L2=m_rhomboideus_L.stl
new file mode 100644
index 00000000..2b511182
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L2=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_L.stl
new file mode 100644
index 00000000..9f1d3b5e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_R.stl
new file mode 100644
index 00000000..84f280ff
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=Mm_thoracis_serratus_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_biceps_brachii_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_biceps_brachii_L.stl
new file mode 100644
index 00000000..53585ca5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_biceps_brachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_coracobrachialis_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_coracobrachialis_L.stl
new file mode 100644
index 00000000..c266e4a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_coracobrachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_acromialis_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_acromialis_L.stl
new file mode 100644
index 00000000..e276e27b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_acromialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_scapularis_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_scapularis_L.stl
new file mode 100644
index 00000000..7ad16595
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_deltoideus_scapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.001.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.001.stl
new file mode 100644
index 00000000..721a9251
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.stl
new file mode 100644
index 00000000..b36b2377
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_infraspinatus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_omotransversarius_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_omotransversarius_L.stl
new file mode 100644
index 00000000..9e80d9d8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_omotransversarius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_subscapularis_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_subscapularis_L.stl
new file mode 100644
index 00000000..976cd752
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_subscapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_supraspinatus_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_supraspinatus_L.stl
new file mode 100644
index 00000000..21cdf19a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_supraspinatus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_major_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_major_L.stl
new file mode 100644
index 00000000..b76b5592
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_minor_L.stl
new file mode 100644
index 00000000..38e3a903
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_teres_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=m_triceps_brachii_long_L.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_triceps_brachii_long_L.stl
new file mode 100644
index 00000000..aa631468
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=m_triceps_brachii_long_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_L=trapezius.stl b/dm_control/suite/dog_assets/extras/SITEScapula_L=trapezius.stl
new file mode 100644
index 00000000..5bf38f0d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_L=trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R1=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R1=m_rhomboideus_R.stl
new file mode 100644
index 00000000..770fe19c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R1=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R2=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R2=m_rhomboideus_R.stl
new file mode 100644
index 00000000..648befef
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R2=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_biceps_brachii_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_biceps_brachii_R.stl
new file mode 100644
index 00000000..2fd6eb5b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_biceps_brachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_coracobrachialis_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_coracobrachialis_R.stl
new file mode 100644
index 00000000..cdb37e00
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_coracobrachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_acromialis_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_acromialis_R.stl
new file mode 100644
index 00000000..bb3c3521
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_acromialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_scapularis_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_scapularis_R.stl
new file mode 100644
index 00000000..f55cebd8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_deltoideus_scapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.001.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.001.stl
new file mode 100644
index 00000000..e822ff12
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.stl
new file mode 100644
index 00000000..2c054209
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_infraspinatus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_omotransversarius_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_omotransversarius_R.stl
new file mode 100644
index 00000000..2c054465
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_omotransversarius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_subscapularis_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_subscapularis_R.stl
new file mode 100644
index 00000000..61ad2992
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_subscapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_supraspinatus_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_supraspinatus_R.stl
new file mode 100644
index 00000000..b0d1257b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_supraspinatus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_major_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_major_R.stl
new file mode 100644
index 00000000..77fb0bf7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_minor_R.stl
new file mode 100644
index 00000000..f01c5d00
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_teres_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=m_triceps_brachii_long_R.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_triceps_brachii_long_R.stl
new file mode 100644
index 00000000..da3f3a46
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=m_triceps_brachii_long_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEScapula_R=trapezius.stl b/dm_control/suite/dog_assets/extras/SITEScapula_R=trapezius.stl
new file mode 100644
index 00000000..b3507f3a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEScapula_R=trapezius.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_L.stl b/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_L.stl
new file mode 100644
index 00000000..2ae1c765
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_R.stl b/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_R.stl
new file mode 100644
index 00000000..51b56757
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull=m_digastricus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_lat_R.stl b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_lat_R.stl
new file mode 100644
index 00000000..4077af59
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_L.stl b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_L.stl
new file mode 100644
index 00000000..5fe7407b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_R.stl b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_R.stl
new file mode 100644
index 00000000..de3b8b84
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull=m_pterygoideus_med_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_L.stl
new file mode 100644
index 00000000..c1bd1696
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_R.stl
new file mode 100644
index 00000000..2ad14199
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1=m_masseter_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_L.stl
new file mode 100644
index 00000000..8b94c40e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_R.stl
new file mode 100644
index 00000000..d9390985
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1=m_temporalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1L=m_brachiocephalicus.stl b/dm_control/suite/dog_assets/extras/SITESkull_1L=m_brachiocephalicus.stl
new file mode 100644
index 00000000..2e0acbbb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1L=m_brachiocephalicus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_1R=m_brachiocephalicus.stl b/dm_control/suite/dog_assets/extras/SITESkull_1R=m_brachiocephalicus.stl
new file mode 100644
index 00000000..172bb8e3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_1R=m_brachiocephalicus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_2=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITESkull_2=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..fb81e206
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_2=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_L.stl
new file mode 100644
index 00000000..551b8865
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_R.stl
new file mode 100644
index 00000000..15940393
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_2=m_masseter_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_L.stl
new file mode 100644
index 00000000..d09c501d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_R.stl
new file mode 100644
index 00000000..7d724c63
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_2=m_temporalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..1c8ff19f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..c8bbdab0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_L.stl
new file mode 100644
index 00000000..89a606ac
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_R.stl
new file mode 100644
index 00000000..0bcc8981
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L(2).stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L(2).stl
new file mode 100644
index 00000000..62779779
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L.stl
new file mode 100644
index 00000000..6e86db68
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R(2).stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R(2).stl
new file mode 100644
index 00000000..a7c9235b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R.stl
new file mode 100644
index 00000000..7a4c2e0f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_obliquus_capitis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_L.stl
new file mode 100644
index 00000000..cb8d23df
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_R.stl
new file mode 100644
index 00000000..c83c0a52
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_L.stl
new file mode 100644
index 00000000..1efb485a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_R.stl
new file mode 100644
index 00000000..dc586d8c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_dorsalis_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_L.stl
new file mode 100644
index 00000000..b2da99aa
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_R.stl
new file mode 100644
index 00000000..0507a5bb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_lateralis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_L.stl
new file mode 100644
index 00000000..3cee5b33
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_R.stl
new file mode 100644
index 00000000..c5d22b9d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_rectus_capitis_ventr_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L1.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L1.stl
new file mode 100644
index 00000000..a167bda1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L2.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L2.stl
new file mode 100644
index 00000000..1f6886ca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_L2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.001.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.001.stl
new file mode 100644
index 00000000..ee906221
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.stl
new file mode 100644
index 00000000..8354a3bb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_splenius_R1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_L.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_L.stl
new file mode 100644
index 00000000..9235a835
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_R.stl b/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_R.stl
new file mode 100644
index 00000000..a1a13300
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESkull_=m_sternocephalicus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_1.stl
new file mode 100644
index 00000000..6ed76d96
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_2.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_2.stl
new file mode 100644
index 00000000..f8972d32
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_pectorales_superficiales_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_L.stl
new file mode 100644
index 00000000..e12cb599
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_R.stl
new file mode 100644
index 00000000..294aa5b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternocephalicus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_L.stl
new file mode 100644
index 00000000..c306bc95
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_R.stl
new file mode 100644
index 00000000..54648920
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternohyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_L.stl
new file mode 100644
index 00000000..3d160a2a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_R.stl
new file mode 100644
index 00000000..ad73d2be
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=m_sternothyroideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_1=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_1=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..c696405a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_1=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_2=m_pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/SITESternum_2=m_pectorales_superficiales_1.stl
new file mode 100644
index 00000000..e18e0f3f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_2=m_pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..4bfd3393
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..b55590b4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_2=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_2=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_2=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..34f1802a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_2=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_3=m_pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/SITESternum_3=m_pectorales_superficiales_1.stl
new file mode 100644
index 00000000..25efb045
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_3=m_pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..71d0b4e2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..bcf49355
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_3=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_3=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_3=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..44f2684d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_3=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..2c3160c2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..36418fb1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_4=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_4=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_4=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..401e807c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_4=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..888e55e5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..8fe72eb0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_5=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_5=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_5=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..8a3d71e0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_5=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..69da5e86
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..108ad953
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_6=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_6=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_6=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..100df7ab
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_6=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..6a4226d5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..cfd2f935
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_7=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_7=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_7=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..6a858cc6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_7=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..017f1b49
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..65c90ca6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_8=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITESternum_8=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITESternum_8=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..873900cc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITESternum_8=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_L.stl
new file mode 100644
index 00000000..e85f2332
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_R.stl
new file mode 100644
index 00000000..477335f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-10=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_L.stl
new file mode 100644
index 00000000..9d3b4590
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_R.stl
new file mode 100644
index 00000000..c62ae2b2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_L.stl
new file mode 100644
index 00000000..f8e5043a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_R.stl
new file mode 100644
index 00000000..950e6e84
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-11=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_L.stl
new file mode 100644
index 00000000..4e44bcf7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_R.stl
new file mode 100644
index 00000000..c2cb2ff3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_L.stl
new file mode 100644
index 00000000..63b454af
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_R.stl
new file mode 100644
index 00000000..3ae81852
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-12=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_L.stl b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_L.stl
new file mode 100644
index 00000000..6915bd77
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_R.stl b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_R.stl
new file mode 100644
index 00000000..3e3317b6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_L.stl b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_L.stl
new file mode 100644
index 00000000..d8424ea2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_R.stl b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_R.stl
new file mode 100644
index 00000000..3c9c8a95
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-13=m_psoas_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..71c39fcd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..3147d44c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..7154946e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..ed016622
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_L.stl
new file mode 100644
index 00000000..d996c9f2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_R.stl
new file mode 100644
index 00000000..71f1cb1b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-1=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..4de5080c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..953413ba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..dfee9111
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..99573480
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_L.stl
new file mode 100644
index 00000000..386f03fc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_R.stl
new file mode 100644
index 00000000..11562cd6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-2=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..1ef796d7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..dd0c70c5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..beb5e245
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..a8a7c58f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_L.stl
new file mode 100644
index 00000000..f41836e4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_R.stl
new file mode 100644
index 00000000..582222db
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-3=m_longus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..683773ec
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..3b6818b6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-4=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..679a007b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..735158d5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-5=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_L.stl b/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_L.stl
new file mode 100644
index 00000000..6b261cad
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_R.stl b/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_R.stl
new file mode 100644
index 00000000..46d43b0b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET-6=m.longissimus_cervicis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..28e30d94
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..8649feca
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..ce70f1fc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..6cca5fb4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_10=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..5622993b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..ba468329
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..9250122a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..e7ab14d5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_11=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..258ffd1d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..075a9c7e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..516401f8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..0d342af0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_12=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..3ef6cd29
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..f0df3054
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..69f9184e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..b5ac6ebb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_13=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..99c96eb3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..43ee6b1d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_1=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_L.stl b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_L.stl
new file mode 100644
index 00000000..e255737d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_R.stl b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_R.stl
new file mode 100644
index 00000000..6542b798
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_capitis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..72bfe7fd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..6eb1da54
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_2=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..f8f75517
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..8dc014ce
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_3=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_L.stl b/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_L.stl
new file mode 100644
index 00000000..e14f5fbc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_R.stl b/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_R.stl
new file mode 100644
index 00000000..9b6ec812
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_3=m_splenius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..90330f35
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..e61ec099
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_L.stl
new file mode 100644
index 00000000..d96ea66d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_R.stl
new file mode 100644
index 00000000..c798215a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_L.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_L.stl
new file mode 100644
index 00000000..a1e54699
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_R.stl b/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_R.stl
new file mode 100644
index 00000000..f1c441e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_4=m_splenius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..da7d5140
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..11e5f0c5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_5=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_L.stl
new file mode 100644
index 00000000..faea8ce0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_R.stl
new file mode 100644
index 00000000..a6d1a60a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_5=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..94e1ca70
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..0cfa1985
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_6=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_L.stl b/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_L.stl
new file mode 100644
index 00000000..01fa4ecb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_R.stl b/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_R.stl
new file mode 100644
index 00000000..ba8467a2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_6=m_rhomboideus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..a9225ad9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..e37a4f60
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_7=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..c5b98c88
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..17e055c0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_8=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_L.stl b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_L.stl
new file mode 100644
index 00000000..b293a4fe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_R.stl b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_R.stl
new file mode 100644
index 00000000..8d408cb8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_lumborum_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_L.stl
new file mode 100644
index 00000000..e6a53c01
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_R.stl
new file mode 100644
index 00000000..30bffa79
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITET_9=m_longissimus_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETarsus_L_I=m_peroneus_longus_L.stl b/dm_control/suite/dog_assets/extras/SITETarsus_L_I=m_peroneus_longus_L.stl
new file mode 100644
index 00000000..3e25d879
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETarsus_L_I=m_peroneus_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETarsus_R_I=m_peroneus_longus_R.stl b/dm_control/suite/dog_assets/extras/SITETarsus_R_I=m_peroneus_longus_R.stl
new file mode 100644
index 00000000..e4bd6349
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETarsus_R_I=m_peroneus_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETarsus_central_L=m_tibialis_cranialis_L.stl b/dm_control/suite/dog_assets/extras/SITETarsus_central_L=m_tibialis_cranialis_L.stl
new file mode 100644
index 00000000..151ad9a1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETarsus_central_L=m_tibialis_cranialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETarsus_central_R=m_tibialis_cranialis_R.stl b/dm_control/suite/dog_assets/extras/SITETarsus_central_R=m_tibialis_cranialis_R.stl
new file mode 100644
index 00000000..44045918
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETarsus_central_R=m_tibialis_cranialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L1=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L1=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..31c55541
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L1=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L2=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L2=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..9c257e9f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L2=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_abductor_cruris_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_abductor_cruris_L.stl
new file mode 100644
index 00000000..14805027
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_abductor_cruris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_biceps_femoris_L.stl
new file mode 100644
index 00000000..68032efc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_gracilis_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_gracilis_L.stl
new file mode 100644
index 00000000..c65d5f38
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_gracilis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_peroneus_longus_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_peroneus_longus_L.stl
new file mode 100644
index 00000000..8cac4ff3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_peroneus_longus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_popliteus_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_popliteus_L.stl
new file mode 100644
index 00000000..825df959
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_popliteus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_sartorius_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_sartorius_L.stl
new file mode 100644
index 00000000..30f92c6f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_sartorius_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_L=m_semitendinosus_L.stl b/dm_control/suite/dog_assets/extras/SITETibia_L=m_semitendinosus_L.stl
new file mode 100644
index 00000000..cb7a6830
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_L=m_semitendinosus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R1=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R1=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..15cb7d62
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R1=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R2=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R2=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..9bf01045
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R2=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_abductor_cruris_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_abductor_cruris_R.stl
new file mode 100644
index 00000000..34d1cbea
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_abductor_cruris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_biceps_femoris_R.stl
new file mode 100644
index 00000000..c5ab18e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_gracilis_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_gracilis_R.stl
new file mode 100644
index 00000000..a3decfbe
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_gracilis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_peroneus_longus_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_peroneus_longus_R.stl
new file mode 100644
index 00000000..d1483719
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_peroneus_longus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_popliteus_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_popliteus_R.stl
new file mode 100644
index 00000000..65f55cd5
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_popliteus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_sartorius_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_sartorius_R.stl
new file mode 100644
index 00000000..a7d329dc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_sartorius_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITETibia_R=m_semitendinosus_R.stl b/dm_control/suite/dog_assets/extras/SITETibia_R=m_semitendinosus_R.stl
new file mode 100644
index 00000000..f18962c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITETibia_R=m_semitendinosus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m.flexor_carpi_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m.flexor_carpi_ulnaris_L.stl
new file mode 100644
index 00000000..1db8820b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m.flexor_carpi_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_anconeus_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_anconeus_L.stl
new file mode 100644
index 00000000..61400cdd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_anconeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_biceps_brachii_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_biceps_brachii_L.stl
new file mode 100644
index 00000000..782a3f6f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_biceps_brachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_extensor_digitorum_lat_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_extensor_digitorum_lat_L.stl
new file mode 100644
index 00000000..6f86bdcd
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_extensor_digitorum_lat_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..a4519ea0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_tensor_f.antebrachii_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_tensor_f.antebrachii_L.stl
new file mode 100644
index 00000000..fb4fab09
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_tensor_f.antebrachii_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_accessory_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_accessory_L.stl
new file mode 100644
index 00000000..772a514b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_accessory_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_lateral_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_lateral_L.stl
new file mode 100644
index 00000000..e3720c1b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_lateral_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_long_L(2).stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_long_L(2).stl
new file mode 100644
index 00000000..0b130e66
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_long_L(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_medial_L.stl b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_medial_L.stl
new file mode 100644
index 00000000..2ebb6430
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_L=m_triceps_brachii_medial_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m.flexor_carpi_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m.flexor_carpi_ulnaris_R.stl
new file mode 100644
index 00000000..6cbe39a8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m.flexor_carpi_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_anconeus_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_anconeus_R.stl
new file mode 100644
index 00000000..d22e59cf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_anconeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_biceps_brachii_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_biceps_brachii_R.stl
new file mode 100644
index 00000000..34254a09
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_biceps_brachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_extensor_digitorum_lat_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_extensor_digitorum_lat_R.stl
new file mode 100644
index 00000000..e5870e63
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_extensor_digitorum_lat_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..c187c0cc
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_tensor_f.antebrachii_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_tensor_f.antebrachii_R.stl
new file mode 100644
index 00000000..69dcfb39
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_tensor_f.antebrachii_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_accessory_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_accessory_R.stl
new file mode 100644
index 00000000..c7da4050
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_accessory_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_lateral_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_lateral_R.stl
new file mode 100644
index 00000000..9ad15deb
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_lateral_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_long_R(2).stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_long_R(2).stl
new file mode 100644
index 00000000..51503fb8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_long_R(2).stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_medial_R.stl b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_medial_R.stl
new file mode 100644
index 00000000..8e663c08
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEUlna_R=m_triceps_brachii_medial_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_L.stl b/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_L.stl
new file mode 100644
index 00000000..69101c48
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_R.stl b/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_R.stl
new file mode 100644
index 00000000..ee180ca4
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEXiphoid_cartilage=m_transverses_thoracis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L1=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L1=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..db933434
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L1=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L2=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L2=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..648052d3
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L2=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.extensor_capri_ulnaris_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.extensor_capri_ulnaris_L.stl
new file mode 100644
index 00000000..67cb7fae
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.extensor_capri_ulnaris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.pronator_teres_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.pronator_teres_L.stl
new file mode 100644
index 00000000..04f7582b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m.pronator_teres_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_anconeus_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_anconeus_L.stl
new file mode 100644
index 00000000..6c6d40b6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_anconeus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_brachialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_brachialis_L.stl
new file mode 100644
index 00000000..592e7224
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_brachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_coracobrachialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_coracobrachialis_L.stl
new file mode 100644
index 00000000..1d95f84c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_coracobrachialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_acromialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_acromialis_L.stl
new file mode 100644
index 00000000..7eae6a65
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_acromialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_scapularis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_scapularis_L.stl
new file mode 100644
index 00000000..c85cafe8
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_deltoideus_scapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_capri_radialis_L.stl
new file mode 100644
index 00000000..a057514c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_digitorum_communis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_digitorum_communis_L.stl
new file mode 100644
index 00000000..9efae47e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_extensor_digitorum_communis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_capri_radialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_capri_radialis_L.stl
new file mode 100644
index 00000000..7b986e1c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_capri_radialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_profundus_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_profundus_L.stl
new file mode 100644
index 00000000..2ec84390
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_profundus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_superficialis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_superficialis_L.stl
new file mode 100644
index 00000000..3bf196f1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_flexor_digitorum_superficialis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_latissimus_dors.stl
new file mode 100644
index 00000000..47d46e54
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_1.stl
new file mode 100644
index 00000000..0e70900a
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_2.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_2.stl
new file mode 100644
index 00000000..a1a376bf
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_pectorales_superficiales_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_subscapularis_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_subscapularis_L.stl
new file mode 100644
index 00000000..e3ed580d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_subscapularis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supinator_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supinator_L.stl
new file mode 100644
index 00000000..b69f6b78
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supinator_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supraspinatus_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supraspinatus_L.stl
new file mode 100644
index 00000000..4fb0acec
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_supraspinatus_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_major_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_major_L.stl
new file mode 100644
index 00000000..ef6985b9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_major_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_minor_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_minor_L.stl
new file mode 100644
index 00000000..da94bfd7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_teres_minor_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_accessory_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_accessory_L.stl
new file mode 100644
index 00000000..1ef15965
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_accessory_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_lateral_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_lateral_L.stl
new file mode 100644
index 00000000..e47b5c2e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_lateral_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_medial_L.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_medial_L.stl
new file mode 100644
index 00000000..cbabdbba
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=m_triceps_brachii_medial_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_L=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_L=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..6b35ee30
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_L=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R1=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R1=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..13d66170
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R1=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R2=m_brachiocephalicus_.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R2=m_brachiocephalicus_.stl
new file mode 100644
index 00000000..a1a89e9c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R2=m_brachiocephalicus_.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.extensor_capri_ulnaris_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.extensor_capri_ulnaris_R.stl
new file mode 100644
index 00000000..2672bfe1
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.extensor_capri_ulnaris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.pronator_teres_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.pronator_teres_R.stl
new file mode 100644
index 00000000..186dd4b0
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m.pronator_teres_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_anconeus_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_anconeus_R.stl
new file mode 100644
index 00000000..0f20237c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_anconeus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_brachialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_brachialis_R.stl
new file mode 100644
index 00000000..bbaedda2
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_brachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_coracobrachialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_coracobrachialis_R.stl
new file mode 100644
index 00000000..5a0aad76
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_coracobrachialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_acromialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_acromialis_R.stl
new file mode 100644
index 00000000..e397d2da
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_acromialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_scapularis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_scapularis_R.stl
new file mode 100644
index 00000000..b08dbb52
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_deltoideus_scapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_capri_radialis_R.stl
new file mode 100644
index 00000000..2b49857b
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_digitorum_communis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_digitorum_communis_R.stl
new file mode 100644
index 00000000..b445983d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_extensor_digitorum_communis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_capri_radialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_capri_radialis_R.stl
new file mode 100644
index 00000000..0bf9ad46
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_capri_radialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_profundus_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_profundus_R.stl
new file mode 100644
index 00000000..480691c7
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_profundus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_superficialis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_superficialis_R.stl
new file mode 100644
index 00000000..95c4f425
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_flexor_digitorum_superficialis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_latissimus_dors.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_latissimus_dors.stl
new file mode 100644
index 00000000..71e3bac9
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_latissimus_dors.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_1.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_1.stl
new file mode 100644
index 00000000..c9ed0393
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_1.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_2.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_2.stl
new file mode 100644
index 00000000..4cafd571
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_pectorales_superficiales_2.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_subscapularis_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_subscapularis_R.stl
new file mode 100644
index 00000000..9804b28d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_subscapularis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supinator_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supinator_R.stl
new file mode 100644
index 00000000..8804fc6d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supinator_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supraspinatus_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supraspinatus_R.stl
new file mode 100644
index 00000000..a13ff8e6
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_supraspinatus_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_major_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_major_R.stl
new file mode 100644
index 00000000..ed578664
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_major_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_minor_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_minor_R.stl
new file mode 100644
index 00000000..ecd881ce
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_teres_minor_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_accessory_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_accessory_R.stl
new file mode 100644
index 00000000..4034e503
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_accessory_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_lateral_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_lateral_R.stl
new file mode 100644
index 00000000..0d754743
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_lateral_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_medial_R.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_medial_R.stl
new file mode 100644
index 00000000..f152835f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=m_triceps_brachii_medial_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEhumerus_R=mm_thoracis_m_pectoralis_profundus.stl b/dm_control/suite/dog_assets/extras/SITEhumerus_R=mm_thoracis_m_pectoralis_profundus.stl
new file mode 100644
index 00000000..e178490c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEhumerus_R=mm_thoracis_m_pectoralis_profundus.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.001.stl b/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.001.stl
new file mode 100644
index 00000000..a4ec277d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.stl b/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.stl
new file mode 100644
index 00000000..5c566f6c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEibula_L=m_tibialis_caudalis_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.001.stl b/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.001.stl
new file mode 100644
index 00000000..f6be2133
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.001.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.stl b/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.stl
new file mode 100644
index 00000000..8f43ab47
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEibula_R=m_tibialis_caudalis_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_L=m_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_L=m_biceps_femoris_L.stl
new file mode 100644
index 00000000..3603d47f
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_L=m_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_R=m_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_R=m_biceps_femoris_R.stl
new file mode 100644
index 00000000..5840e399
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEm_biceps_femoris_tendon_R=m_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_L=m_biceps_femoris_L.stl b/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_L=m_biceps_femoris_L.stl
new file mode 100644
index 00000000..f689e572
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_L=m_biceps_femoris_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_R=m_biceps_femoris_R.stl b/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_R=m_biceps_femoris_R.stl
new file mode 100644
index 00000000..3ddf0a1c
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEsacrotuberous_lig_R=m_biceps_femoris_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_L=m_tensor_f.latae_L.stl b/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_L=m_tensor_f.latae_L.stl
new file mode 100644
index 00000000..3833157e
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_L=m_tensor_f.latae_L.stl differ
diff --git a/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_R=m_tensor_f.latae_R.stl b/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_R=m_tensor_f.latae_R.stl
new file mode 100644
index 00000000..7bb55f60
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/SITEvastus_lat_ander_fascia_lata_R=m_tensor_f.latae_R.stl differ
diff --git a/dm_control/suite/dog_assets/extras/skin_texture_highres.png b/dm_control/suite/dog_assets/extras/skin_texture_highres.png
new file mode 100644
index 00000000..3cbeb84d
Binary files /dev/null and b/dm_control/suite/dog_assets/extras/skin_texture_highres.png differ
diff --git a/dm_control/suite/dog_assets/skin_texture.png b/dm_control/suite/dog_assets/skin_texture.png
new file mode 100644
index 00000000..f2d05093
Binary files /dev/null and b/dm_control/suite/dog_assets/skin_texture.png differ
diff --git a/dm_control/suite/dog_assets/tennis_ball.png b/dm_control/suite/dog_assets/tennis_ball.png
new file mode 100644
index 00000000..2200b591
Binary files /dev/null and b/dm_control/suite/dog_assets/tennis_ball.png differ
diff --git a/dm_control/suite/explore.py b/dm_control/suite/explore.py
new file mode 100644
index 00000000..accb0281
--- /dev/null
+++ b/dm_control/suite/explore.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Control suite environments explorer."""
+
+
+from absl import app
+from absl import flags
+from dm_control import suite
+from dm_control.suite.wrappers import action_noise
+
+from dm_control import viewer
+
+
+_ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]
+
+flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
+ 'Optional \'domain_name.task_name\' pair specifying the '
+ 'environment to load. If unspecified a prompt will appear to '
+ 'select one.')
+flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
+flags.DEFINE_bool('visualize_reward', True,
+ 'Whether to vary the colors of geoms according to the '
+ 'current reward value.')
+flags.DEFINE_float('action_noise', 0.,
+ 'Standard deviation of Gaussian noise to apply to actions, '
+ 'expressed as a fraction of the max-min range for each '
+ 'action dimension. Defaults to 0, i.e. no noise.')
+FLAGS = flags.FLAGS
+
+
+def prompt_environment_name(prompt, values):
+ environment_name = None
+ while not environment_name:
+ environment_name = input(prompt)
+ if environment_name not in values:
+ print('"%s" is not a valid environment name.' % environment_name)
+ environment_name = None
+ return environment_name
+
+
+def main(argv):
+ del argv
+ environment_name = FLAGS.environment_name
+ if environment_name is None:
+ print('\n '.join(['Available environments:'] + _ALL_NAMES))
+ environment_name = prompt_environment_name(
+ 'Please select an environment name: ', _ALL_NAMES)
+
+ index = _ALL_NAMES.index(environment_name)
+ domain_name, task_name = suite.ALL_TASKS[index]
+
+ task_kwargs = {}
+ if not FLAGS.timeout:
+ task_kwargs['time_limit'] = float('inf')
+
+ def loader():
+ env = suite.load(
+ domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
+ env.task.visualize_reward = FLAGS.visualize_reward
+ if FLAGS.action_noise > 0:
+ env = action_noise.Wrapper(env, scale=FLAGS.action_noise)
+ return env
+
+ viewer.launch(loader)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/dm_control/suite/finger.py b/dm_control/suite/finger.py
index fc5b9632..20ef2c81 100644
--- a/dm_control/suite/finger.py
+++ b/dm_control/suite/finger.py
@@ -15,23 +15,15 @@
"""Finger Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
-
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
_DEFAULT_TIME_LIMIT = 20 # (seconds)
_CONTROL_TIMESTEP = .02 # (seconds)
@@ -55,30 +47,38 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Spin task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Spin(random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the easy Turn task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Turn(target_radius=_EASY_TARGET_SIZE, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the hard Turn task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Turn(target_radius=_HARD_TARGET_SIZE, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -90,7 +90,7 @@ def touch(self):
def hinge_velocity(self):
"""Returns the velocity of the hinge joint."""
- return self.named.data.sensordata['hinge_velocity']
+ return self.named.data.sensordata['hinge_velocity'][0]
def tip_position(self):
"""Returns the (x,z) position of the tip relative to the hinge."""
@@ -134,13 +134,14 @@ def __init__(self, random=None):
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
"""
- super(Spin, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
physics.named.model.site_rgba['target', 3] = 0
physics.named.model.site_rgba['tip', 3] = 0
physics.named.model.dof_damping['hinge'] = .03
_set_random_joint_angles(physics, self.random)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns state and touch sensors, and target info."""
@@ -168,7 +169,7 @@ def __init__(self, target_radius, random=None):
automatically (default).
"""
self._target_radius = target_radius
- super(Turn, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
target_angle = self.random.uniform(-np.pi, np.pi)
@@ -181,6 +182,8 @@ def initialize_episode(self, physics):
_set_random_joint_angles(physics, self.random)
+ super().initialize_episode(physics)
+
def get_observation(self, physics):
"""Returns state, touch sensors, and target info."""
obs = collections.OrderedDict()
@@ -198,7 +201,7 @@ def get_reward(self, physics):
def _set_random_joint_angles(physics, random, max_attempts=1000):
"""Sets the joints to a random collision-free state."""
- for _ in xrange(max_attempts):
+ for _ in range(max_attempts):
randomizers.randomize_limited_and_rotational_joints(physics, random)
# Check for collisions.
physics.after_reset()
diff --git a/dm_control/suite/finger.xml b/dm_control/suite/finger.xml
index 4692bdff..3b359864 100644
--- a/dm_control/suite/finger.xml
+++ b/dm_control/suite/finger.xml
@@ -13,7 +13,7 @@
-
+
diff --git a/dm_control/suite/fish.py b/dm_control/suite/fish.py
index 651c6aa2..6711595e 100644
--- a/dm_control/suite/fish.py
+++ b/dm_control/suite/fish.py
@@ -15,21 +15,14 @@
"""Fish Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
@@ -51,21 +44,26 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the Fish Upright task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Upright(random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Fish Swim task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Swim(random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -105,7 +103,7 @@ def __init__(self, random=None):
integer seed for creating a new `RandomState`, or None to select a seed
automatically.
"""
- super(Upright, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Randomizes the tail and fin angles and the orientation of the Fish."""
@@ -115,6 +113,7 @@ def initialize_episode(self, physics):
physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
# Hide the target. It's irrelevant for this task.
physics.named.model.geom_rgba['target', 3] = 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of joint angles, velocities and uprightness."""
@@ -140,7 +139,7 @@ def __init__(self, random=None):
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
"""
- super(Swim, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
@@ -153,6 +152,7 @@ def initialize_episode(self, physics):
physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of joints, target direction and velocities."""
diff --git a/dm_control/suite/hopper.py b/dm_control/suite/hopper.py
index e9f08376..5ed5e09d 100644
--- a/dm_control/suite/hopper.py
+++ b/dm_control/suite/hopper.py
@@ -15,14 +15,8 @@
"""Hopper domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -30,7 +24,6 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
@@ -54,21 +47,25 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns a Hopper that strives to stand upright, balancing its pose."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Hopper(hopping=False, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns a Hopper that strives to hop forward."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Hopper(hopping=True, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -81,7 +78,7 @@ def height(self):
def speed(self):
"""Returns horizontal speed of the Hopper."""
- return self.named.data.subtree_linvel['torso', 'x']
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
def touch(self):
"""Returns the signals from two foot touch sensors."""
@@ -102,18 +99,19 @@ def __init__(self, hopping, random=None):
automatically (default).
"""
self._hopping = hopping
- super(Hopper, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
randomizers.randomize_limited_and_rotational_joints(physics, self.random)
self._timeout_progress = 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of positions, velocities and touch sensors."""
obs = collections.OrderedDict()
# Ignores horizontal position to maintain translational invariance:
- obs['position'] = physics.data.qpos[1:]
+ obs['position'] = physics.data.qpos[1:].copy()
obs['velocity'] = physics.velocity()
obs['touch'] = physics.touch()
return obs
diff --git a/dm_control/suite/hopper.xml b/dm_control/suite/hopper.xml
index c97c4bc6..84ad72eb 100644
--- a/dm_control/suite/hopper.xml
+++ b/dm_control/suite/hopper.xml
@@ -52,6 +52,7 @@
+
diff --git a/dm_control/suite/humanoid.py b/dm_control/suite/humanoid.py
index 814f29bc..b3dffa9a 100644
--- a/dm_control/suite/humanoid.py
+++ b/dm_control/suite/humanoid.py
@@ -15,14 +15,8 @@
"""Humanoid Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -30,7 +24,6 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
_DEFAULT_TIME_LIMIT = 25
@@ -53,39 +46,48 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Stand task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Humanoid(move_speed=0, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Walk task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Humanoid(move_speed=_WALK_SPEED, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Humanoid(move_speed=_RUN_SPEED, pure_state=False, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add()
-def run_pure_state(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def run_pure_state(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns the Run task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Humanoid(move_speed=_RUN_SPEED, pure_state=True, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -101,11 +103,11 @@ def head_height(self):
def center_of_mass_position(self):
"""Returns position of the center-of-mass."""
- return self.named.data.subtree_com['torso']
+ return self.named.data.subtree_com['torso'].copy()
def center_of_mass_velocity(self):
"""Returns the velocity of the center-of-mass."""
- return self.named.data.subtree_linvel['torso']
+ return self.named.data.sensordata['torso_subtreelinvel'].copy()
def torso_vertical_orientation(self):
"""Returns the z-projection of the torso orientation matrix."""
@@ -113,7 +115,7 @@ def torso_vertical_orientation(self):
def joint_angles(self):
"""Returns the state without global orientation or position."""
- return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
+ return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
def extremities(self):
"""Returns end effector positions in egocentric frame."""
@@ -145,14 +147,11 @@ def __init__(self, move_speed, pure_state, random=None):
"""
self._move_speed = move_speed
self._pure_state = pure_state
- super(Humanoid, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
- In 'standing' mode, use initial orientation and small velocities.
- In 'random' mode, randomize joint angles and let fall to the floor.
-
Args:
physics: An instance of `Physics`.
@@ -164,6 +163,7 @@ def initialize_episode(self, physics):
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns either the pure state or a set of egocentric features."""
diff --git a/dm_control/suite/humanoid.xml b/dm_control/suite/humanoid.xml
index de4b3958..32b84c52 100644
--- a/dm_control/suite/humanoid.xml
+++ b/dm_control/suite/humanoid.xml
@@ -8,12 +8,9 @@
-
-
-
-
+
-
+
@@ -21,6 +18,13 @@
+
+
+
+
+
+
+
@@ -31,84 +35,99 @@
-
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
-
-
+
+
+
+
-
-
+
+
-
-
-
-
-
-
+
+
+
+
+
+
+
-
-
+
+
-
-
-
+
+
+
+
-
-
+
+
-
-
-
-
-
-
+
+
+
+
+
+
+
-
+
-
+
+
-
+
+
-
+
+
-
+
-
+
+
-
+
+
-
+
+
@@ -140,19 +159,43 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/humanoid_CMU.py b/dm_control/suite/humanoid_CMU.py
index 1d6eec20..1a04e9d7 100644
--- a/dm_control/suite/humanoid_CMU.py
+++ b/dm_control/suite/humanoid_CMU.py
@@ -15,14 +15,8 @@
"""Humanoid_CMU Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -30,7 +24,6 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
_DEFAULT_TIME_LIMIT = 20
@@ -52,21 +45,36 @@ def get_model_and_assets():
@SUITE.add()
-def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Stand task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = HumanoidCMU(move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = HumanoidCMU(move_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add()
-def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = HumanoidCMU(move_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -86,7 +94,7 @@ def center_of_mass_position(self):
def center_of_mass_velocity(self):
"""Returns the velocity of the center-of-mass."""
- return self.named.data.subtree_linvel['thorax']
+ return self.named.data.sensordata['thorax_subtreelinvel'].copy()
def torso_vertical_orientation(self):
"""Returns the z-projection of the thorax orientation matrix."""
@@ -94,7 +102,7 @@ def torso_vertical_orientation(self):
def joint_angles(self):
"""Returns the state without global orientation or position."""
- return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
+ return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint.
def extremities(self):
"""Returns end effector positions in egocentric frame."""
@@ -123,7 +131,7 @@ def __init__(self, move_speed, random=None):
automatically (default).
"""
self._move_speed = move_speed
- super(HumanoidCMU, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets a random collision-free configuration at the start of each episode.
@@ -138,6 +146,7 @@ def initialize_episode(self, physics):
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns a set of egocentric features."""
diff --git a/dm_control/suite/humanoid_CMU.xml b/dm_control/suite/humanoid_CMU.xml
index 238d6110..9a41a166 100644
--- a/dm_control/suite/humanoid_CMU.xml
+++ b/dm_control/suite/humanoid_CMU.xml
@@ -120,6 +120,7 @@
+
@@ -172,7 +173,7 @@
-
+
@@ -275,6 +276,7 @@
+
diff --git a/dm_control/suite/tests/loader_test.py b/dm_control/suite/loader_test.py
similarity index 91%
rename from dm_control/suite/tests/loader_test.py
rename to dm_control/suite/loader_test.py
index cbce4f50..b0f6df1e 100644
--- a/dm_control/suite/tests/loader_test.py
+++ b/dm_control/suite/loader_test.py
@@ -15,14 +15,7 @@
"""Tests for the dm_control.suite loader."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
from absl.testing import absltest
-
from dm_control import suite
from dm_control.rl import control
diff --git a/dm_control/suite/lqr.py b/dm_control/suite/lqr.py
index 6a021354..967af1d1 100644
--- a/dm_control/suite/lqr.py
+++ b/dm_control/suite/lqr.py
@@ -15,28 +15,19 @@
"""Procedurally generated LQR domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-
import os
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import xml_tools
-
from lxml import etree
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
-from dm_control.utils import resources
+from dm_control.utils import io as resources
_DEFAULT_TIME_LIMIT = float('inf')
_CONTROL_COST_COEF = 0.1
@@ -60,26 +51,31 @@ def get_model_and_assets(n_bodies, n_actuators, random):
@SUITE.add()
-def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns an LQR environment with 2 bodies of which the first is actuated."""
return _make_lqr(n_bodies=2,
n_actuators=1,
control_cost_coef=_CONTROL_COST_COEF,
time_limit=time_limit,
- random=random)
+ random=random,
+ environment_kwargs=environment_kwargs)
@SUITE.add()
-def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns an LQR environment with 6 bodies of which first 2 are actuated."""
return _make_lqr(n_bodies=6,
n_actuators=2,
control_cost_coef=_CONTROL_COST_COEF,
time_limit=time_limit,
- random=random)
+ random=random,
+ environment_kwargs=environment_kwargs)
-def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random):
+def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random,
+ environment_kwargs):
"""Returns a LQR environment.
Args:
@@ -91,6 +87,8 @@ def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random):
random: Either an existing `numpy.random.RandomState` instance, an
integer seed for creating a new `RandomState`, or None to select a seed
automatically.
+ environment_kwargs: A `dict` specifying keyword arguments for the
+ environment, or None.
Returns:
A LQR environment with `n_bodies` bodies of which first `n_actuators` are
@@ -104,7 +102,9 @@ def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random):
random=random)
physics = Physics.from_xml_string(model_string, assets=assets)
task = LQRLevel(control_cost_coef, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ **environment_kwargs)
def _make_body(body_id, stiffness_range, damping_range, random):
@@ -165,13 +165,13 @@ def _make_model(n_bodies,
raise ValueError('At most 1 actuator per body.')
file_path = os.path.join(os.path.dirname(__file__), 'lqr.xml')
- xml_file = resources.GetResourceAsFile(file_path)
- mjcf = xml_tools.parse(xml_file)
+ with resources.GetResourceAsFile(file_path) as xml_file:
+ mjcf = xml_tools.parse(xml_file)
parent = mjcf.find('./worldbody')
actuator = etree.SubElement(mjcf.getroot(), 'actuator')
tendon = etree.SubElement(mjcf.getroot(), 'tendon')
- for body in xrange(n_bodies):
+ for body in range(n_bodies):
# Inserting body.
child = _make_body(body, stiffness_range, damping_range, random)
site_name = 'site_{}'.format(body)
@@ -229,7 +229,7 @@ def __init__(self, control_cost_coef, random=None):
raise ValueError('control_cost_coef must be positive.')
self._control_cost_coef = control_cost_coef
- super(LQRLevel, self).__init__(random=random)
+ super().__init__(random=random)
@property
def control_cost_coef(self):
@@ -240,6 +240,7 @@ def initialize_episode(self, physics):
ndof = physics.model.nq
unit = self.random.randn(ndof)
physics.data.qpos[:] = np.sqrt(2) * unit / np.linalg.norm(unit)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of the state."""
diff --git a/dm_control/suite/lqr_solver.py b/dm_control/suite/lqr_solver.py
index 4cc6ec63..e27738cc 100644
--- a/dm_control/suite/lqr_solver.py
+++ b/dm_control/suite/lqr_solver.py
@@ -19,67 +19,10 @@
https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-from absl import logging
from dm_control.mujoco import wrapper
-
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-try:
- import scipy.linalg as sp # pylint: disable=g-import-not-at-top
-except ImportError:
- sp = None
-
-
-def _solve_dare(a, b, q, r):
- """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.
-
- Algebraic Riccati Equation:
- ```none
- P_{t-1} = Q + A' * P_{t} * A -
- A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
- ```
-
- Args:
- a: A 2 dimensional numpy array, transition matrix A.
- b: A 2 dimensional numpy array, control matrix B.
- q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
- r: A 2 dimensional numpy array, symmetric positive definite cost matrix
-
- Returns:
- A numpy array, a real symmetric matrix P which is the solution to DARE.
-
- Raises:
- RuntimeError: If the computed P matrix is not symmetric and
- positive-definite.
- """
- p = np.eye(len(a))
- for _ in xrange(1000000):
- a_p = a.T.dot(p) # A' * P_t
- a_p_b = np.dot(a_p, b) # A' * P_t * B
- # Algebraic Riccati Equation.
- p_next = q + np.dot(a_p, a) - a_p_b.dot(
- np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
- p_next += p_next.T
- p_next *= .5
- if np.abs(p - p_next).max() < 1e-12:
- break
- p = p_next
- else:
- logging.warn('DARE solver did not converge')
- try:
- # Check that the result is symmetric and positive-definite.
- np.linalg.cholesky(p_next)
- except np.linalg.LinAlgError:
- raise RuntimeError('ARE solver failed: P matrix is not symmetric and '
- 'positive-definite.')
- return p_next
+import scipy.linalg as scipy_linalg
def solve(env):
@@ -117,7 +60,7 @@ def solve(env):
(dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))
# Control transition matrix b.
- b = env.physics.data.actuator_moment.T
+ b = np.vstack((np.eye(m), np.zeros((n - m, m))))
bc = np.linalg.solve(mass, b)
b = dt * np.vstack((dt * bc, bc))
@@ -127,15 +70,8 @@ def solve(env):
# Control cost Hessian r.
r = env.task.control_cost_coef * np.eye(m)
- if sp:
- # Use scipy's faster DARE solver if available.
- solve_dare = sp.solve_discrete_are
- else:
- # Otherwise fall back on a slower internal implementation.
- solve_dare = _solve_dare
-
# Solve the discrete algebraic Riccati equation.
- p = solve_dare(a, b, q, r)
+ p = scipy_linalg.solve_discrete_are(a, b, q, r)
k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))
# Under optimal policy, state tends to 0 like beta^n_timesteps
diff --git a/dm_control/suite/tests/lqr_test.py b/dm_control/suite/lqr_test.py
similarity index 73%
rename from dm_control/suite/tests/lqr_test.py
rename to dm_control/suite/lqr_test.py
index 214fe822..0dae8ae1 100644
--- a/dm_control/suite/tests/lqr_test.py
+++ b/dm_control/suite/lqr_test.py
@@ -15,24 +15,14 @@
"""Tests specific to the LQR domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import math
-import unittest
-# Internal dependencies.
from absl import logging
-
from absl.testing import absltest
from absl.testing import parameterized
-
from dm_control.suite import lqr
from dm_control.suite import lqr_solver
-
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
class LqrTest(parameterized.TestCase):
@@ -45,22 +35,6 @@ def test_lqr_optimal_policy(self, make_env):
p, k, beta = lqr_solver.solve(env)
self.assertPolicyisOptimal(env, p, k, beta)
- @parameterized.named_parameters(
- ('lqr_2_1', lqr.lqr_2_1),
- ('lqr_6_2', lqr.lqr_6_2))
- @unittest.skipUnless(
- condition=lqr_solver.sp,
- reason='scipy is not available, so non-scipy DARE solver is the default.')
- def test_lqr_optimal_policy_no_scipy(self, make_env):
- env = make_env()
- old_sp = lqr_solver.sp
- try:
- lqr_solver.sp = None # Force the solver to use the non-scipy code path.
- p, k, beta = lqr_solver.solve(env)
- finally:
- lqr_solver.sp = old_sp
- self.assertPolicyisOptimal(env, p, k, beta)
-
def assertPolicyisOptimal(self, env, p, k, beta):
tolerance = 1e-3
n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
@@ -71,7 +45,7 @@ def assertPolicyisOptimal(self, env, p, k, beta):
initial_state = np.hstack((timestep.observation['position'],
timestep.observation['velocity']))
logging.info('Measuring total cost over %d steps.', n_steps)
- for _ in xrange(n_steps):
+ for _ in range(n_steps):
x = np.hstack((timestep.observation['position'],
timestep.observation['velocity']))
# u = k*x is the optimal policy
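Reviewer note: `assertPolicyisOptimal` picks `n_steps` so that a state decaying like `beta**n` falls below `tolerance`. A small sketch of that arithmetic, assuming `0 < beta < 1`; the function name `steps_to_tolerance` is hypothetical and only wraps the expression from the test.

```python
import math


def steps_to_tolerance(beta, tolerance=1e-3):
    """Smallest n with beta**n <= tolerance, using the test's expression."""
    return int(math.ceil(math.log10(tolerance) / math.log10(beta)))


# With beta = 0.5 the state halves every step.
n_steps = steps_to_tolerance(0.5)
```

A `beta` at or above 1 would make the ratio non-positive or undefined, so the formula only makes sense for a stable closed loop.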
diff --git a/dm_control/suite/manipulator.py b/dm_control/suite/manipulator.py
index 3e253b16..bb66c99c 100644
--- a/dm_control/suite/manipulator.py
+++ b/dm_control/suite/manipulator.py
@@ -15,16 +15,9 @@
"""Planar Manipulator domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
-from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
@@ -44,6 +37,8 @@
'finger', 'fingertip', 'thumb', 'thumbtip']
_ALL_PROPS = frozenset(['ball', 'target_ball', 'cup',
'peg', 'target_peg', 'slot'])
+_TOUCH_SENSORS = ['palm_touch', 'finger_touch', 'thumb_touch',
+ 'fingertip_touch', 'thumbtip_touch']
SUITE = containers.TaggedTasks()
@@ -73,84 +68,90 @@ def make_model(use_peg, insert):
@SUITE.add('benchmarking', 'hard')
-def bring_ball(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+def bring_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns manipulator bring task with the ball prop."""
use_peg = False
insert = False
physics = Physics.from_xml_string(*make_model(use_peg, insert))
- task = Bring(use_peg, insert, observe_target, random=random)
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
@SUITE.add('hard')
-def bring_peg(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+def bring_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns manipulator bring task with the peg prop."""
use_peg = True
insert = False
physics = Physics.from_xml_string(*make_model(use_peg, insert))
- task = Bring(use_peg, insert, observe_target, random=random)
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
@SUITE.add('hard')
-def insert_ball(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+def insert_ball(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns manipulator insert task with the ball prop."""
use_peg = False
insert = True
physics = Physics.from_xml_string(*make_model(use_peg, insert))
- task = Bring(use_peg, insert, observe_target, random=random)
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
@SUITE.add('hard')
-def insert_peg(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+def insert_peg(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns manipulator insert task with the peg prop."""
use_peg = True
insert = True
physics = Physics.from_xml_string(*make_model(use_peg, insert))
- task = Bring(use_peg, insert, observe_target, random=random)
+ task = Bring(use_peg=use_peg, insert=insert,
+ fully_observable=fully_observable, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
class Physics(mujoco.Physics):
"""Physics with additional features for the Planar Manipulator domain."""
- def bounded_position(self):
- """Returns the position, with unbounded angles as sine/cosine."""
- state = []
- hinge_joint = enums.mjtJoint.mjJNT_HINGE
- for joint_id in range(self.model.njnt):
- joint_value = self.named.data.qpos[joint_id]
- if (not self.model.jnt_limited[joint_id] and
- self.model.jnt_type[joint_id] == hinge_joint): # Unbounded hinge.
- state += [np.sin(joint_value), np.cos(joint_value)]
- else:
- state.append(joint_value)
- return np.asarray(state)
-
- def body_location(self, body):
- """Returns the x,z position and y orientation of a body."""
- body_position = self.named.model.body_pos[body, ['x', 'z']]
- body_orientation = self.named.model.body_quat[body, ['qw', 'qy']]
- return np.hstack((body_position, body_orientation))
-
- def proprioception(self):
- """Returns the arm state, with unbounded angles as sine/cosine."""
- arm = []
- for joint in _ARM_JOINTS:
- joint_value = self.named.data.qpos[joint]
- if not self.named.model.jnt_limited[joint]:
- arm += [np.sin(joint_value), np.cos(joint_value)]
- else:
- arm.append(joint_value)
- return np.hstack(arm + [self.named.data.qvel[_ARM_JOINTS]])
+ def bounded_joint_pos(self, joint_names):
+ """Returns joint positions as (sin, cos) values."""
+ joint_pos = self.named.data.qpos[joint_names]
+ return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
+
+ def joint_vel(self, joint_names):
+ """Returns joint velocities."""
+ return self.named.data.qvel[joint_names]
+
+ def body_2d_pose(self, body_names, orientation=True):
+ """Returns positions and/or orientations of bodies."""
+ if not isinstance(body_names, str):
+ body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices.
+ pos = self.named.data.xpos[body_names, ['x', 'z']]
+ if orientation:
+ ori = self.named.data.xquat[body_names, ['qw', 'qy']]
+ return np.hstack([pos, ori])
+ else:
+ return pos
def touch(self):
- return np.log1p(self.data.sensordata)
+ return np.log1p(self.named.data.sensordata[_TOUCH_SENSORS])
def site_distance(self, site1, site2):
site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
@@ -160,13 +161,15 @@ def site_distance(self, site1, site2):
class Bring(base.Task):
"""A Bring `Task`: bring the prop to the target."""
- def __init__(self, use_peg, insert, observe_target, random=None):
+ def __init__(self, use_peg, insert, fully_observable, random=None):
"""Initialize an instance of the `Bring` task.
Args:
use_peg: A `bool`, whether to replace the ball prop with the peg prop.
insert: A `bool`, whether to insert the prop in a receptacle.
- observe_target: A `bool`, whether the observation contains target info.
+ fully_observable: A `bool`, whether the observation should contain the
+ position and velocity of the object being manipulated and the target
+ location.
random: Optional, either a `numpy.random.RandomState` instance, an
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
@@ -174,14 +177,16 @@ def __init__(self, use_peg, insert, observe_target, random=None):
self._use_peg = use_peg
self._target = 'target_peg' if use_peg else 'target_ball'
self._object = 'peg' if self._use_peg else 'ball'
+ self._object_joints = ['_'.join([self._object, dim]) for dim in 'xzy']
self._receptacle = 'slot' if self._use_peg else 'cup'
self._insert = insert
- self._observe_target = observe_target
- super(Bring, self).__init__(random=random)
+ self._fully_observable = fully_observable
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
- # local shortcuts
+ # Local aliases
+ choice = self.random.choice
uniform = self.random.uniform
model = physics.named.model
data = physics.named.data
@@ -191,7 +196,7 @@ def initialize_episode(self, physics):
while penetrating:
# Randomise angles of arm joints.
- is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool)
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(bool)
joint_range = model.jnt_range[_ARM_JOINTS]
lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
@@ -218,8 +223,8 @@ def initialize_episode(self, physics):
# Randomise object location.
object_init_probs = [_P_IN_HAND, _P_IN_TARGET, 1-_P_IN_HAND-_P_IN_TARGET]
- init_type = np.random.choice(['in_hand', 'in_target', 'uniform'], 1,
- p=object_init_probs)[0]
+ init_type = choice(['in_hand', 'in_target', 'uniform'],
+ p=object_init_probs)
if init_type == 'in_target':
object_x = target_x
object_z = target_z
@@ -236,26 +241,25 @@ def initialize_episode(self, physics):
object_angle = uniform(0, 2*np.pi)
data.qvel[self._object + '_x'] = uniform(-5, 5)
- data.qpos[self._object + '_x'] = object_x
- data.qpos[self._object + '_z'] = object_z
- data.qpos[self._object + '_y'] = object_angle
+ data.qpos[self._object_joints] = object_x, object_z, object_angle
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
+ super().initialize_episode(physics)
+
def get_observation(self, physics):
"""Returns either features or only sensors (to be used with pixels)."""
obs = collections.OrderedDict()
- if self._observe_target:
- obs['position'] = physics.bounded_position()
- obs['hand'] = physics.body_location('hand')
- obs['target'] = physics.body_location(self._target)
- obs['velocity'] = physics.velocity()
- obs['touch'] = physics.touch()
- else:
- obs['proprioception'] = physics.proprioception()
- obs['touch'] = physics.touch()
+ obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
+ obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
+ obs['touch'] = physics.touch()
+ if self._fully_observable:
+ obs['hand_pos'] = physics.body_2d_pose('hand')
+ obs['object_pos'] = physics.body_2d_pose(self._object)
+ obs['object_vel'] = physics.joint_vel(self._object_joints)
+ obs['target_pos'] = physics.body_2d_pose(self._target)
return obs
def _is_close(self, distance):
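Reviewer note: `bounded_joint_pos` now encodes every requested joint angle as a `(sin, cos)` pair. The sketch below is a NumPy-free stand-in (an assumption made for brevity, not the code in this change) showing why the encoding helps: two angles on either side of the ±π wrap are far apart as raw values but close once encoded.

```python
import math


def bounded_joint_pos(joint_angles):
    """Stand-in for the method in the diff: one (sin, cos) pair per joint."""
    return [(math.sin(q), math.cos(q)) for q in joint_angles]


eps = 1e-3
angle_a = math.pi - eps   # Just below +pi.
angle_b = -math.pi + eps  # Just above -pi; physically almost the same pose.

enc_a, enc_b = bounded_joint_pos([angle_a, angle_b])

raw_gap = abs(angle_a - angle_b)
encoded_gap = math.hypot(enc_a[0] - enc_b[0], enc_a[1] - enc_b[1])
```

The raw gap is close to 2π while the encoded gap is about `2 * eps`, so the observation stays continuous across the wrap.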
diff --git a/dm_control/suite/manipulator.xml b/dm_control/suite/manipulator.xml
index 6e9b2014..f12c9cc2 100644
--- a/dm_control/suite/manipulator.xml
+++ b/dm_control/suite/manipulator.xml
@@ -11,7 +11,7 @@
- >
+
@@ -20,7 +20,7 @@
-
+
@@ -161,7 +161,7 @@
-
+
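Reviewer note: every task factory in this change gains an `environment_kwargs=None` parameter that is normalised with `environment_kwargs or {}` and forwarded to `control.Environment`. A minimal sketch of the pattern follows; `FakeEnvironment` and `make_task` are hypothetical stand-ins so the snippet runs without `dm_control`, and `flat_observation` is used only as an example key.

```python
class FakeEnvironment:
    """Hypothetical stand-in for control.Environment; records extra kwargs."""

    def __init__(self, physics, task, time_limit=float('inf'), **kwargs):
        self.physics = physics
        self.task = task
        self.time_limit = time_limit
        self.extra = kwargs


def make_task(time_limit=20, environment_kwargs=None):
    """Hypothetical factory mirroring the pattern used in the suite."""
    # A None default avoids sharing one mutable dict between calls.
    environment_kwargs = environment_kwargs or {}
    return FakeEnvironment('physics', 'task', time_limit=time_limit,
                           **environment_kwargs)


default_env = make_task()
custom_env = make_task(environment_kwargs={'flat_observation': True})
```

Callers that pass nothing get the previous behaviour, and callers that need extra options no longer have to construct the environment by hand.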
diff --git a/dm_control/suite/pendulum.py b/dm_control/suite/pendulum.py
index f6331f66..77063262 100644
--- a/dm_control/suite/pendulum.py
+++ b/dm_control/suite/pendulum.py
@@ -15,21 +15,14 @@
"""Pendulum domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
@@ -45,11 +38,14 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns pendulum swingup task ."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = SwingUp(random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
class Physics(mujoco.Physics):
@@ -61,7 +57,7 @@ def pole_vertical(self):
def angular_velocity(self):
"""Returns the angular velocity of the pole."""
- return self.named.data.qvel['hinge']
+ return self.named.data.qvel['hinge'].copy()
def pole_orientation(self):
"""Returns both horizontal and vertical components of pole frame."""
@@ -79,7 +75,7 @@ def __init__(self, random=None):
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
"""
- super(SwingUp, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -91,6 +87,7 @@ def initialize_episode(self, physics):
"""
physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation.
diff --git a/dm_control/suite/point_mass.py b/dm_control/suite/point_mass.py
index c499e00e..6c4d516b 100644
--- a/dm_control/suite/point_mass.py
+++ b/dm_control/suite/point_mass.py
@@ -15,14 +15,8 @@
"""Point-mass domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -30,7 +24,6 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
_DEFAULT_TIME_LIMIT = 20
@@ -43,19 +36,23 @@ def get_model_and_assets():
@SUITE.add('benchmarking', 'easy')
-def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the easy point_mass task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = PointMass(randomize_gains=False, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add()
-def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the hard point_mass task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = PointMass(randomize_gains=True, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
class Physics(mujoco.Physics):
@@ -84,7 +81,7 @@ def __init__(self, randomize_gains, random=None):
automatically (default).
"""
self._randomize_gains = randomize_gains
- super(PointMass, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -108,6 +105,7 @@ def initialize_episode(self, physics):
parallel = abs(np.dot(dir1, dir2)) > 0.9
physics.model.wrap_prm[[0, 1]] = dir1
physics.model.wrap_prm[[2, 3]] = dir2
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of the state."""
diff --git a/dm_control/suite/quadruped.py b/dm_control/suite/quadruped.py
new file mode 100644
index 00000000..86fb5424
--- /dev/null
+++ b/dm_control/suite/quadruped.py
@@ -0,0 +1,477 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Quadruped Domain."""
+
+import collections
+
+from dm_control import mujoco
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+from lxml import etree
+import numpy as np
+from scipy import ndimage
+
+enums = mjbindings.enums
+mjlib = mjbindings.mjlib
+
+
+_DEFAULT_TIME_LIMIT = 20
+_CONTROL_TIMESTEP = .02
+
+# Horizontal speeds above which the move reward is 1.
+_RUN_SPEED = 5
+_WALK_SPEED = 0.5
+
+# Constants related to terrain generation.
+_HEIGHTFIELD_ID = 0
+_TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth.
+_TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters).
+
+# Named model elements.
+_TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
+_WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(floor_size=None, terrain=False, rangefinders=False,
+ walls_and_ball=False):
+ """Returns the model XML string."""
+ xml_string = common.read_model('quadruped.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Set floor size.
+ if floor_size is not None:
+ floor_geom = mjcf.find('.//geom[@name=\'floor\']')
+ floor_geom.attrib['size'] = f'{floor_size} {floor_size} .5'
+
+ # Remove walls, ball and target.
+ if not walls_and_ball:
+ for wall in _WALLS:
+ wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
+ wall_geom.getparent().remove(wall_geom)
+
+ # Remove ball.
+ ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
+ ball_body.getparent().remove(ball_body)
+
+ # Remove target.
+ target_site = xml_tools.find_element(mjcf, 'site', 'target')
+ target_site.getparent().remove(target_site)
+
+ # Remove terrain.
+ if not terrain:
+ terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
+ terrain_geom.getparent().remove(terrain_geom)
+
+ # Remove rangefinders if they're not used, as range computations can be
+ # expensive, especially in a scene with heightfields.
+ if not rangefinders:
+ rangefinder_sensors = mjcf.findall('.//rangefinder')
+ for rf in rangefinder_sensors:
+ rf.getparent().remove(rf)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+@SUITE.add()
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Move(desired_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _RUN_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Move(desired_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
+ """Returns the Escape task."""
+ xml_string = make_model(floor_size=40, terrain=True, rangefinders=True)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Escape(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add()
+def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Fetch task."""
+ xml_string = make_model(walls_and_ball=True)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Fetch(random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Quadruped domain."""
+
+ def _reload_from_data(self, data):
+ super()._reload_from_data(data)
+ # Clear cached sensor names when the physics is reloaded.
+ self._sensor_types_to_names = {}
+ self._hinge_names = []
+
+ def _get_sensor_names(self, *sensor_types):
+ try:
+ sensor_names = self._sensor_types_to_names[sensor_types]
+ except KeyError:
+ [sensor_ids] = np.where(
+ np.isin(self.model.sensor_type, sensor_types).ravel()
+ )
+ sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
+ self._sensor_types_to_names[sensor_types] = sensor_names
+ return sensor_names
+
+ def torso_upright(self):
+ """Returns the dot-product of the torso z-axis and the global z-axis."""
+ return np.asarray(self.named.data.xmat['torso', 'zz'])
+
+ def torso_velocity(self):
+ """Returns the velocity of the torso, in the local frame."""
+ return self.named.data.sensordata['velocimeter'].copy()
+
+ def egocentric_state(self):
+ """Returns the state without global orientation or position."""
+ if not self._hinge_names:
+ [hinge_ids] = np.nonzero(self.model.jnt_type ==
+ enums.mjtJoint.mjJNT_HINGE)
+ self._hinge_names = [self.model.id2name(j_id, 'joint')
+ for j_id in hinge_ids]
+ return np.hstack((self.named.data.qpos[self._hinge_names],
+ self.named.data.qvel[self._hinge_names],
+ self.data.act))
+
+ def toe_positions(self):
+ """Returns toe positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
+ return torso_to_toe.dot(torso_frame)
+
+ def force_torque(self):
+ """Returns scaled force/torque sensor readings at the toes."""
+ force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
+ enums.mjtSensor.mjSENS_TORQUE)
+ return np.arcsinh(self.named.data.sensordata[force_torque_sensors])
+
+ def imu(self):
+ """Returns IMU-like sensor readings."""
+ imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
+ enums.mjtSensor.mjSENS_ACCELEROMETER)
+ return self.named.data.sensordata[imu_sensors]
+
+ def rangefinder(self):
+ """Returns scaled rangefinder sensor readings."""
+ rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
+ rf_readings = self.named.data.sensordata[rf_sensors]
+ no_intersection = -1.0
+ return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))
+
+ def origin_distance(self):
+ """Returns the distance from the origin to the workspace."""
+ return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))
+
+ def origin(self):
+ """Returns origin position in the torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ return -torso_pos.dot(torso_frame)
+
+ def ball_state(self):
+ """Returns ball position and velocity relative to the torso frame."""
+ data = self.named.data
+ torso_frame = data.xmat['torso'].reshape(3, 3)
+ ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
+ ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
+ ball_rot_vel = data.qvel['ball_root'][3:]
+ ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
+ return ball_state.dot(torso_frame).ravel()
+
+ def target_position(self):
+ """Returns target position in torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_target = self.named.data.site_xpos['target'] - torso_pos
+ return torso_to_target.dot(torso_frame)
+
+ def ball_to_target_distance(self):
+ """Returns horizontal distance from the ball to the target."""
+ ball_to_target = (self.named.data.site_xpos['target'] -
+ self.named.data.xpos['ball'])
+ return np.linalg.norm(ball_to_target[:2])
+
+ def self_to_ball_distance(self):
+ """Returns horizontal distance from the quadruped workspace to the ball."""
+ self_to_ball = (self.named.data.site_xpos['workspace']
+ -self.named.data.xpos['ball'])
+ return np.linalg.norm(self_to_ball[:2])
+
+
+def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
+ """Find a height with no contacts given a body orientation.
+
+ Args:
+ physics: An instance of `Physics`.
+ orientation: A quaternion.
+ x_pos: A float. Position along global x-axis.
+ y_pos: A float. Position along global y-axis.
+ Raises:
+ RuntimeError: If a non-contacting configuration has not been found after
+ 10,000 attempts.
+ """
+ z_pos = 0.0 # Start embedded in the floor.
+ num_contacts = 1
+ num_attempts = 0
+ # Move up in 1cm increments until no contacts.
+ while num_contacts > 0:
+ try:
+ with physics.reset_context():
+ physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
+ physics.named.data.qpos['root'][3:] = orientation
+ except control.PhysicsError:
+ # We may encounter a PhysicsError here due to filling the contact
+ # buffer, in which case we simply increment the height and continue.
+ pass
+ num_contacts = physics.data.ncon
+ z_pos += 0.01
+ num_attempts += 1
+ if num_attempts > 10000:
+ raise RuntimeError('Failed to find a non-contacting configuration.')
+
+
+def _common_observations(physics):
+ """Returns the observations common to all tasks."""
+ obs = collections.OrderedDict()
+ obs['egocentric_state'] = physics.egocentric_state()
+ obs['torso_velocity'] = physics.torso_velocity()
+ obs['torso_upright'] = physics.torso_upright()
+ obs['imu'] = physics.imu()
+ obs['force_torque'] = physics.force_torque()
+ return obs
+
+
+def _upright_reward(physics, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+
+ Args:
+ physics: an instance of `Physics`.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ deviation = np.cos(np.deg2rad(deviation_angle))
+ return rewards.tolerance(
+ physics.torso_upright(),
+ bounds=(deviation, float('inf')),
+ sigmoid='linear',
+ margin=1 + deviation,
+ value_at_margin=0)
+
+
+class Move(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, desired_speed, random=None):
+ """Initializes an instance of `Move`.
+
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._desired_speed = desired_speed
+ super().__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Move reward term.
+ move_reward = rewards.tolerance(
+ physics.torso_velocity()[0],
+ bounds=(self._desired_speed, float('inf')),
+ margin=self._desired_speed,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ return _upright_reward(physics) * move_reward
+
+
+class Escape(base.Task):
+ """A quadruped task solved by escaping a bowl-shaped terrain."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Get heightfield resolution, assert that it is square.
+ res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
+ assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
+ # Sinusoidal bowl shape.
+ row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
+ radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
+ bowl_shape = .5 - np.cos(2*np.pi*radius)/2
+ # Random smooth bumps.
+ terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
+ bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
+ smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
+ # Terrain is elementwise product.
+ terrain = bowl_shape * smooth_bumps
+ start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
+ physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
+ super().initialize_episode(physics)
+
+ # If we have a rendering context, we need to re-upload the modified
+ # heightfield data.
+ if physics.contexts:
+ with physics.contexts.gl.make_current() as ctx:
+ ctx.call(mjlib.mjr_uploadHField,
+ physics.model.ptr,
+ physics.contexts.mujoco.ptr,
+ _HEIGHTFIELD_ID)
+
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['origin'] = physics.origin()
+ obs['rangefinder'] = physics.rangefinder()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Escape reward term.
+ terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ escape_reward = rewards.tolerance(
+ physics.origin_distance(),
+ bounds=(terrain_size, float('inf')),
+ margin=terrain_size,
+ value_at_margin=0,
+ sigmoid='linear')
+
+ return _upright_reward(physics, deviation_angle=20) * escape_reward
+
+
+class Fetch(base.Task):
+ """A quadruped task solved by bringing a ball to the origin."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Initial configuration, random azimuth and horizontal position.
+ azimuth = self.random.uniform(0, 2*np.pi)
+ orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
+ spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
+ x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
+ _find_non_contacting_height(physics, orientation, x_pos, y_pos)
+
+ # Initial ball state.
+ physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
+ -spawn_radius, spawn_radius, size=(2,))
+ physics.named.data.qpos['ball_root'][2] = 2
+ physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['ball_state'] = physics.ball_state()
+ obs['target_position'] = physics.target_position()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Reward for moving close to the ball.
+ arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
+ workspace_radius = physics.named.model.site_size['workspace', 0]
+ ball_radius = physics.named.model.geom_size['ball', 0]
+ reach_reward = rewards.tolerance(
+ physics.self_to_ball_distance(),
+ bounds=(0, workspace_radius+ball_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
+ # Reward for bringing the ball to the target.
+ target_radius = physics.named.model.site_size['target', 0]
+ fetch_reward = rewards.tolerance(
+ physics.ball_to_target_distance(),
+ bounds=(0, target_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
+ reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward)
+
+ return _upright_reward(physics) * reach_then_fetch
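The `Fetch` reward above gates the fetch term on the reach term through `reach_reward * (0.5 + 0.5*fetch_reward)`, so an agent far from the ball earns nothing even if the ball happens to sit on the target. The pure-Python sketch below illustrates that composition. `linear_tolerance` is a hypothetical, simplified stand-in for `rewards.tolerance(..., sigmoid='linear', value_at_margin=0)`, assumed here to equal 1 inside the bounds and to fall linearly to 0 at a distance of `margin` outside them. It is not the library implementation, and the numeric arguments are made up for illustration.

```python
def linear_tolerance(x, lower, upper, margin):
    """Simplified stand-in (assumption): 1 inside [lower, upper], then a
    linear fall-off that reaches 0 at `margin` outside the bounds."""
    if lower <= x <= upper:
        return 1.0
    distance = (lower - x) if x < lower else (x - upper)
    return max(0.0, 1.0 - distance / margin)


def reach_then_fetch(reach_reward, fetch_reward):
    """Same composition as in `Fetch.get_reward`: reaching alone is worth at
    most 0.5, and fetching only counts once the ball has been reached."""
    return reach_reward * (0.5 + 0.5 * fetch_reward)


# Illustrative values only: the agent is at the ball, the ball is off-target.
reach = linear_tolerance(0.5, lower=0.0, upper=1.0, margin=2.0)
fetch = linear_tolerance(5.0, lower=0.0, upper=1.0, margin=2.0)
combined = reach_then_fetch(reach, fetch)
```

With these inputs `reach` is 1, `fetch` is 0, and `combined` is 0.5, which shows why the product form rewards approaching the ball before moving it.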
diff --git a/dm_control/suite/quadruped.xml b/dm_control/suite/quadruped.xml
new file mode 100644
index 00000000..958d2c0d
--- /dev/null
+++ b/dm_control/suite/quadruped.xml
@@ -0,0 +1,329 @@
diff --git a/dm_control/suite/reacher.py b/dm_control/suite/reacher.py
index 4b701e26..aae3240d 100644
--- a/dm_control/suite/reacher.py
+++ b/dm_control/suite/reacher.py
@@ -15,14 +15,8 @@
"""Reacher domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -30,7 +24,6 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
import numpy as np
SUITE = containers.TaggedTasks()
@@ -45,28 +38,32 @@ def get_model_and_assets():
@SUITE.add('benchmarking', 'easy')
-def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns reacher with sparse reward with 5e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Reacher(target_size=_BIG_TARGET, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
@SUITE.add('benchmarking')
-def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns reacher with sparse reward with 1e-2 tol and randomized target."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = Reacher(target_size=_SMALL_TARGET, random=random)
- return control.Environment(physics, task, time_limit=time_limit)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, **environment_kwargs)
class Physics(mujoco.Physics):
"""Physics simulation with additional features for the Reacher domain."""
def finger_to_target(self):
- """Returns the vector from target to finger in global coordinate."""
- return (self.named.data.geom_xpos['target'] -
- self.named.data.geom_xpos['finger'])
+ """Returns the vector from finger to target in global coordinates."""
+ return (self.named.data.geom_xpos['target', :2] -
+ self.named.data.geom_xpos['finger', :2])
def finger_to_target_dist(self):
"""Returns the signed distance between the finger and target surface."""
@@ -87,19 +84,21 @@ def __init__(self, target_size, random=None):
automatically (default).
"""
self._target_size = target_size
- super(Reacher, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
physics.named.model.geom_size['target', 0] = self._target_size
randomizers.randomize_limited_and_rotational_joints(physics, self.random)
- # randomize target position
+ # Randomize target position.
angle = self.random.uniform(0, 2 * np.pi)
radius = self.random.uniform(.05, .20)
physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
+ super().initialize_episode(physics)
+
def get_observation(self, physics):
"""Returns an observation of the state and the target position."""
obs = collections.OrderedDict()
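The task constructors in this change all gain an `environment_kwargs=None` argument that is normalised with `environment_kwargs or {}` and then unpacked into `control.Environment`. A minimal sketch of that pattern, using a hypothetical `make_environment` function in place of `control.Environment` (its parameter list is an assumption chosen for illustration):

```python
def make_environment(time_limit, control_timestep=None, flat_observation=False):
    """Hypothetical stand-in for `control.Environment`; echoes its settings."""
    return {
        'time_limit': time_limit,
        'control_timestep': control_timestep,
        'flat_observation': flat_observation,
    }


def easy(time_limit=20, environment_kwargs=None):
    """Mirrors the constructor pattern used by the suite tasks."""
    # A `None` default avoids sharing one mutable dict between calls.
    environment_kwargs = environment_kwargs or {}
    return make_environment(time_limit=time_limit, **environment_kwargs)


default_env = easy()
flat_env = easy(environment_kwargs={'flat_observation': True})
```

Callers that pass nothing get the defaults, while extra keyword arguments reach the environment without each task having to enumerate them.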
diff --git a/dm_control/suite/stacker.py b/dm_control/suite/stacker.py
index 8de91d27..4474b108 100644
--- a/dm_control/suite/stacker.py
+++ b/dm_control/suite/stacker.py
@@ -15,16 +15,9 @@
"""Planar Stacker domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
-from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
@@ -60,57 +53,57 @@ def make_model(n_boxes):
@SUITE.add('hard')
-def stack_2(observable=True, time_limit=_TIME_LIMIT, random=None):
+def stack_2(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns stacker task with 2 boxes."""
n_boxes = 2
physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
- task = Stack(n_boxes, observable, random=random)
+ task = Stack(n_boxes=n_boxes,
+ fully_observable=fully_observable,
+ random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
@SUITE.add('hard')
-def stack_4(observable=True, time_limit=_TIME_LIMIT, random=None):
+def stack_4(fully_observable=True, time_limit=_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns stacker task with 4 boxes."""
n_boxes = 4
physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
- task = Stack(n_boxes, observable, random=random)
+ task = Stack(n_boxes=n_boxes,
+ fully_observable=fully_observable,
+ random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
+ **environment_kwargs)
class Physics(mujoco.Physics):
"""Physics with additional features for the Planar Manipulator domain."""
- def bounded_position(self):
- """Returns the state, with unbounded angles as sine/cosine."""
- state = []
- hinge_joint = enums.mjtJoint.mjJNT_HINGE
- for joint_id in range(self.model.njnt):
- joint_value = self.named.data.qpos[joint_id]
- if (not self.model.jnt_limited[joint_id] and
- self.model.jnt_type[joint_id] == hinge_joint): # Unbounded hinge.
- state += [np.sin(joint_value), np.cos(joint_value)]
- else:
- state.append(joint_value)
- return np.asarray(state)
-
- def body_location(self, body):
- """Returns the x,z position and y orientation of a body."""
- body_position = self.named.model.body_pos[body, ['x', 'z']]
- body_orientation = self.named.model.body_quat[body, ['qw', 'qy']]
- return np.hstack((body_position, body_orientation))
-
- def proprioception(self):
- """Returns the arm state, with unbounded angles as sine/cosine."""
- arm = []
- for joint in _ARM_JOINTS:
- joint_value = self.named.data.qpos[joint]
- if not self.named.model.jnt_limited[joint]:
- arm += [np.sin(joint_value), np.cos(joint_value)]
- else:
- arm.append(joint_value)
- return np.hstack(arm + [self.named.data.qvel[_ARM_JOINTS]])
+ def bounded_joint_pos(self, joint_names):
+ """Returns joint positions as (sin, cos) values."""
+ joint_pos = self.named.data.qpos[joint_names]
+ return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T
+
+ def joint_vel(self, joint_names):
+ """Returns joint velocities."""
+ return self.named.data.qvel[joint_names]
+
+ def body_2d_pose(self, body_names, orientation=True):
+ """Returns positions and/or orientations of bodies."""
+ if not isinstance(body_names, str):
+ body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices.
+ pos = self.named.data.xpos[body_names, ['x', 'z']]
+ if orientation:
+ ori = self.named.data.xquat[body_names, ['qw', 'qy']]
+ return np.hstack([pos, ori])
+ else:
+ return pos
def touch(self):
return np.log1p(self.data.sensordata)
@@ -123,23 +116,30 @@ def site_distance(self, site1, site2):
class Stack(base.Task):
"""A Stack `Task`: stack the boxes."""
- def __init__(self, n_boxes, observable, random=None):
+ def __init__(self, n_boxes, fully_observable, random=None):
"""Initialize an instance of the `Stack` task.
Args:
n_boxes: An `int`, number of boxes to stack.
- observable: A `bool`, whether the observation contains target info.
+ fully_observable: A `bool`, whether the observation should contain the
+ positions and velocities of the boxes and the location of the target.
random: Optional, either a `numpy.random.RandomState` instance, an
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
"""
self._n_boxes = n_boxes
- self._observable = observable
- super(Stack, self).__init__(random=random)
+ self._box_names = ['box' + str(b) for b in range(n_boxes)]
+ self._box_joint_names = []
+ for name in self._box_names:
+ for dim in 'xyz':
+ self._box_joint_names.append('_'.join([name, dim]))
+ self._fully_observable = fully_observable
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
- # local shortcuts
+ # Local aliases.
+ randint = self.random.randint
uniform = self.random.uniform
model = physics.named.model
data = physics.named.data
@@ -149,7 +149,7 @@ def initialize_episode(self, physics):
while penetrating:
# Randomise angles of arm joints.
- is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool)
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(bool)
joint_range = model.jnt_range[_ARM_JOINTS]
lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
@@ -160,45 +160,45 @@ def initialize_episode(self, physics):
data.qpos['finger'] = data.qpos['thumb']
# Randomise target location.
- target_height = 2*np.random.randint(self._n_boxes) + 1
+ target_height = 2*randint(self._n_boxes) + 1
box_size = model.geom_size['target', 0]
model.body_pos['target', 'z'] = box_size * target_height
model.body_pos['target', 'x'] = uniform(-.37, .37)
# Randomise box locations.
- for b in range(self._n_boxes):
- box = 'box' + str(b)
- data.qpos[box + '_x'] = uniform(.1, .3)
- data.qpos[box + '_z'] = uniform(0, .7)
- data.qpos[box + '_y'] = uniform(0, 2*np.pi)
+ for name in self._box_names:
+ data.qpos[name + '_x'] = uniform(.1, .3)
+ data.qpos[name + '_z'] = uniform(0, .7)
+ data.qpos[name + '_y'] = uniform(0, 2*np.pi)
# Check for collisions.
physics.after_reset()
penetrating = physics.data.ncon > 0
+ super().initialize_episode(physics)
+
def get_observation(self, physics):
"""Returns either features or only sensors (to be used with pixels)."""
obs = collections.OrderedDict()
- if self._observable:
- box_locations = [physics.body_location('box' + str(b))
- for b in range(self._n_boxes)]
- obs['position'] = physics.bounded_position()
- obs['hand'] = physics.body_location('hand')
- obs['boxes'] = np.hstack(box_locations)
- obs['velocity'] = physics.velocity()
- obs['touch'] = physics.touch()
- else:
- obs['proprioception'] = physics.proprioception()
- obs['touch'] = physics.touch()
+ obs['arm_pos'] = physics.bounded_joint_pos(_ARM_JOINTS)
+ obs['arm_vel'] = physics.joint_vel(_ARM_JOINTS)
+ obs['touch'] = physics.touch()
+ if self._fully_observable:
+ obs['hand_pos'] = physics.body_2d_pose('hand')
+ obs['box_pos'] = physics.body_2d_pose(self._box_names)
+ obs['box_vel'] = physics.joint_vel(self._box_joint_names)
+ obs['target_pos'] = physics.body_2d_pose('target', orientation=False)
return obs
def get_reward(self, physics):
"""Returns a reward to the agent."""
box_size = physics.named.model.geom_size['target', 0]
- def target_to_box(b):
- return rewards.tolerance(physics.site_distance('box' + str(b), 'target'),
- margin=2*box_size)
- box_is_close = max(target_to_box(b) for b in range(self._n_boxes))
- hand_to_target = physics.site_distance('grasp', 'target')
- hand_is_far = rewards.tolerance(hand_to_target, (.1, float('inf')), _CLOSE)
+ min_box_to_target_distance = min(physics.site_distance(name, 'target')
+ for name in self._box_names)
+ box_is_close = rewards.tolerance(min_box_to_target_distance,
+ margin=2*box_size)
+ hand_to_target_distance = physics.site_distance('grasp', 'target')
+ hand_is_far = rewards.tolerance(hand_to_target_distance,
+ bounds=(.1, float('inf')),
+ margin=_CLOSE)
return box_is_close * hand_is_far
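`Stack.__init__` now precomputes the box body names and the per-box joint names that `get_observation` later passes to `body_2d_pose` and `joint_vel`. The sketch below reproduces only that naming logic as a standalone function. The function name `box_and_joint_names` is hypothetical and does not appear in the change.

```python
def box_and_joint_names(n_boxes):
    """Builds the same name lists as `Stack.__init__` for `n_boxes` boxes."""
    box_names = ['box' + str(b) for b in range(n_boxes)]
    joint_names = []
    for name in box_names:
        # Each planar box has an x slide, a y hinge and a z slide joint.
        for dim in 'xyz':
            joint_names.append('_'.join([name, dim]))
    return box_names, joint_names


boxes, joints = box_and_joint_names(2)
```

For two boxes this yields `['box0', 'box1']` and six joint names from `box0_x` through `box1_z`, three per box.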
diff --git a/dm_control/suite/stacker.xml b/dm_control/suite/stacker.xml
index 06e4846c..e57b5a9a 100644
--- a/dm_control/suite/stacker.xml
+++ b/dm_control/suite/stacker.xml
@@ -20,7 +20,7 @@
-
+
diff --git a/dm_control/suite/suite_test.py b/dm_control/suite/suite_test.py
new file mode 100644
index 00000000..7ad36dd0
--- /dev/null
+++ b/dm_control/suite/suite_test.py
@@ -0,0 +1,291 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.suite domains."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import suite
+from dm_control.mujoco.wrapper.mjbindings import constants
+from dm_control.rl import control
+import mock
+import numpy as np
+
+
+_DOMAINS_AND_TASKS = [
+ dict(domain=domain, task=task) for domain, task in suite.ALL_TASKS
+]
+
+
+def uniform_random_policy(action_spec, random=None):
+ lower_bounds = action_spec.minimum
+ upper_bounds = action_spec.maximum
+ # Draw values between -1 and 1 for actions where the min/max is set to
+ # MuJoCo's internal limit.
+ lower_bounds = np.where(
+ np.abs(lower_bounds) >= constants.mjMAXVAL, -1.0, lower_bounds)
+ upper_bounds = np.where(
+ np.abs(upper_bounds) >= constants.mjMAXVAL, 1.0, upper_bounds)
+ random_state = np.random.RandomState(random)
+ def policy(time_step):
+ del time_step # Unused.
+ return random_state.uniform(lower_bounds, upper_bounds)
+ return policy
+
+
+def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10):
+ for _ in range(num_episodes):
+ step_count = 0
+ time_step = env.reset()
+ yield time_step
+ while not time_step.last():
+ action = policy(time_step)
+ time_step = env.step(action)
+ step_count += 1
+ yield time_step
+ if step_count >= max_steps_per_episode:
+ break
+
+
+def make_trajectory(domain, task, seed, **trajectory_kwargs):
+ env = suite.load(domain, task, task_kwargs={'random': seed})
+ policy = uniform_random_policy(env.action_spec(), random=seed)
+ return step_environment(env, policy, **trajectory_kwargs)
+
+
+class SuiteTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def test_constants(self):
+ num_tasks = sum(len(tasks) for tasks in suite.TASKS_BY_DOMAIN.values())
+ self.assertLen(suite.ALL_TASKS, num_tasks)
+
+ def _validate_observation(self, observation_dict, observation_spec):
+ obs = observation_dict.copy()
+ for name, spec in observation_spec.items():
+ arr = obs.pop(name)
+ self.assertEqual(arr.shape, spec.shape)
+ self.assertEqual(arr.dtype, spec.dtype)
+ self.assertTrue(
+ np.all(np.isfinite(arr)),
+ msg='{!r} has non-finite value(s): {!r}'.format(name, arr))
+ self.assertEmpty(
+ obs,
+ msg='Observation contains array(s) that are not in the spec: {!r}'
+ .format(obs))
+
+ def _validate_reward_range(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ else:
+ self.assertIsInstance(time_step.reward, float)
+ self.assertBetween(time_step.reward, 0, 1)
+
+ def _validate_discount(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.discount)
+ else:
+ self.assertIsInstance(time_step.discount, float)
+ self.assertBetween(time_step.discount, 0, 1)
+
+ def _validate_control_range(self, lower_bounds, upper_bounds):
+ for b in lower_bounds:
+ self.assertEqual(b, -1.0)
+ for b in upper_bounds:
+ self.assertEqual(b, 1.0)
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_components_have_names(self, domain, task):
+ env = suite.load(domain, task)
+ model = env.physics.model
+
+ object_types_and_size_fields = [
+ ('body', 'nbody'),
+ ('joint', 'njnt'),
+ ('geom', 'ngeom'),
+ ('site', 'nsite'),
+ ('camera', 'ncam'),
+ ('light', 'nlight'),
+ ('mesh', 'nmesh'),
+ ('hfield', 'nhfield'),
+ ('texture', 'ntex'),
+ ('material', 'nmat'),
+ ('equality', 'neq'),
+ ('tendon', 'ntendon'),
+ ('actuator', 'nu'),
+ ('sensor', 'nsensor'),
+ ('numeric', 'nnumeric'),
+ ('text', 'ntext'),
+ ('tuple', 'ntuple'),
+ ]
+ for object_type, size_field in object_types_and_size_fields:
+ for idx in range(getattr(model, size_field)):
+ object_name = model.id2name(idx, object_type)
+ self.assertNotEqual(object_name, '',
+ msg='Model {!r} contains unnamed {!r} with ID {}.'
+ .format(model.name, object_type, idx))
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_model_has_at_least_2_cameras(self, domain, task):
+ env = suite.load(domain, task)
+ model = env.physics.model
+ self.assertGreaterEqual(model.ncam, 2,
+ 'Model {!r} should have at least 2 cameras, has {}.'
+ .format(model.name, model.ncam))
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_task_conforms_to_spec(self, domain, task):
+ """Tests that the environment timesteps conform to specifications."""
+ is_benchmark = (domain, task) in suite.BENCHMARKING
+ env = suite.load(domain, task)
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+
+ # Check action bounds.
+ if is_benchmark:
+ self._validate_control_range(action_spec.minimum, action_spec.maximum)
+
+ # Step through the environment, applying random actions sampled within the
+ # valid range and check the observations, rewards, and discounts.
+ policy = uniform_random_policy(action_spec)
+ for time_step in step_environment(env, policy):
+ self._validate_observation(time_step.observation, observation_spec)
+ self._validate_discount(time_step)
+ if is_benchmark:
+ self._validate_reward_range(time_step)
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_environment_is_deterministic(self, domain, task):
+ """Tests that identical seeds and actions produce identical trajectories."""
+ seed = 0
+ # Iterate over two trajectories generated using identical sequences of
+ # random actions, and with identical task random states. Check that the
+ # observations, rewards, discounts and step types are identical.
+ trajectory1 = make_trajectory(domain=domain, task=task, seed=seed)
+ trajectory2 = make_trajectory(domain=domain, task=task, seed=seed)
+ for time_step1, time_step2 in zip(trajectory1, trajectory2):
+ self.assertEqual(time_step1.step_type, time_step2.step_type)
+ self.assertEqual(time_step1.reward, time_step2.reward)
+ self.assertEqual(time_step1.discount, time_step2.discount)
+ for key in time_step1.observation.keys():
+ np.testing.assert_array_equal(
+ time_step1.observation[key], time_step2.observation[key],
+ err_msg='Observation {!r} is not equal.'.format(key))
+
+ def assertCorrectColors(self, physics, reward):
+ colors = physics.named.model.mat_rgba
+ for material_name in ('self', 'effector', 'target'):
+ highlight = colors[material_name + '_highlight']
+ default = colors[material_name + '_default']
+ blend_coef = reward ** 4
+ expected = blend_coef * highlight + (1.0 - blend_coef) * default
+ actual = colors[material_name]
+ err_msg = ('Material {!r} has unexpected color.\nExpected: {!r}\n'
+ 'Actual: {!r}'.format(material_name, expected, actual))
+ np.testing.assert_array_almost_equal(expected, actual, err_msg=err_msg)
+
+ @parameterized.parameters(*suite.REWARD_VIZ)
+ def test_visualize_reward(self, domain, task):
+ env = suite.load(domain, task)
+ env.task.visualize_reward = True
+ action = np.zeros(env.action_spec().shape)
+
+ with mock.patch.object(env.task, 'get_reward') as mock_get_reward:
+ mock_get_reward.return_value = -3.0 # Rewards < 0 should be clipped.
+ env.reset()
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=0.0)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 0.5
+ env.step(action)
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 2.0 # Rewards > 1 should be clipped.
+ env.step(action)
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=1.0)
+
+ mock_get_reward.reset_mock()
+ mock_get_reward.return_value = 0.25
+ env.reset()
+ mock_get_reward.assert_called_with(env.physics)
+ self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_task_supports_environment_kwargs(self, domain, task):
+ env = suite.load(domain, task,
+ environment_kwargs=dict(flat_observation=True))
+ # Check that the kwargs are actually passed through to the environment.
+ self.assertSetEqual(set(env.observation_spec()),
+ {control.FLAT_OBSERVATION_KEY})
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_observation_arrays_dont_share_memory(self, domain, task):
+ env = suite.load(domain, task)
+ first_timestep = env.reset()
+ action = np.zeros(env.action_spec().shape)
+ second_timestep = env.step(action)
+ for name, first_array in first_timestep.observation.items():
+ second_array = second_timestep.observation[name]
+ self.assertFalse(
+ np.may_share_memory(first_array, second_array),
+ msg='Consecutive observations of {!r} may share memory.'.format(name))
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_observations_dont_contain_constant_elements(self, domain, task):
+ env = suite.load(domain, task)
+ trajectory = make_trajectory(domain=domain, task=task, seed=0,
+ num_episodes=2, max_steps_per_episode=1000)
+ observations = {name: [] for name in env.observation_spec()}
+ for time_step in trajectory:
+ for name, array in time_step.observation.items():
+ observations[name].append(array)
+
+ failures = []
+
+ for name, array_list in observations.items():
+ # Sampling random uniform actions generally isn't sufficient to trigger
+ # these touch sensors.
+ if (domain in ('manipulator', 'stacker', 'finger') and name == 'touch' or
+ domain == 'quadruped' and name == 'force_torque'):
+ continue
+ stacked_arrays = np.array(array_list)
+ is_constant = np.all(stacked_arrays == stacked_arrays[0], axis=0)
+ has_constant_elements = (
+ is_constant if np.isscalar(is_constant) else np.any(is_constant))
+ if has_constant_elements:
+ failures.append((name, is_constant))
+
+ self.assertEmpty(
+ failures,
+ msg='The following observation(s) contain constant elements:\n{}'
+ .format('\n'.join(':\t'.join([name, str(is_constant)])
+ for (name, is_constant) in failures)))
+
+ @parameterized.parameters(_DOMAINS_AND_TASKS)
+ def test_initial_state_is_randomized(self, domain, task):
+ env = suite.load(domain, task, task_kwargs={'random': 42})
+ obs1 = env.reset().observation
+ obs2 = env.reset().observation
+ self.assertFalse(
+ all(np.all(obs1[k] == obs2[k]) for k in obs1),
+ 'Two consecutive initial states have identical observations.\n'
+ 'First: {}\nSecond: {}'.format(obs1, obs2))
+
+if __name__ == '__main__':
+ absltest.main()
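The `step_environment` generator in the new test module yields the reset time step and then up to `max_steps_per_episode` further steps per episode, stopping early if the episode ends. The sketch below exercises the same control flow against `FakeEnv` and `FakeTimeStep`, which are hypothetical stubs written for this illustration and not part of `dm_control`.

```python
class FakeTimeStep:
    """Hypothetical stub exposing only the `last()` method used here."""

    def __init__(self, is_last):
        self._is_last = is_last

    def last(self):
        return self._is_last


class FakeEnv:
    """Hypothetical stub environment that ends after `episode_length` steps."""

    def __init__(self, episode_length):
        self._episode_length = episode_length
        self._t = 0

    def reset(self):
        self._t = 0
        return FakeTimeStep(False)

    def step(self, action):
        del action  # Unused.
        self._t += 1
        return FakeTimeStep(self._t >= self._episode_length)


def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10):
    """Same control flow as the helper in `suite_test.py`."""
    for _ in range(num_episodes):
        step_count = 0
        time_step = env.reset()
        yield time_step
        while not time_step.last():
            action = policy(time_step)
            time_step = env.step(action)
            step_count += 1
            yield time_step
            if step_count >= max_steps_per_episode:
                break


policy = lambda time_step: 0
short = list(step_environment(FakeEnv(3), policy, num_episodes=2))
capped = list(step_environment(FakeEnv(100), policy))
```

Short episodes contribute one reset plus three steps each, while long episodes are cut at ten steps after the reset, so the step cap bounds test time regardless of a task's time limit.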
diff --git a/dm_control/suite/swimmer.py b/dm_control/suite/swimmer.py
index 057004c5..5fdcb735 100644
--- a/dm_control/suite/swimmer.py
+++ b/dm_control/suite/swimmer.py
@@ -14,14 +14,8 @@
# ============================================================================
"""Procedurally generated Swimmer domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -29,10 +23,8 @@
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
-
from lxml import etree
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
_DEFAULT_TIME_LIMIT = 30
_CONTROL_TIMESTEP = .03 # (Seconds)
@@ -54,30 +46,38 @@ def get_model_and_assets(n_joints):
@SUITE.add('benchmarking')
-def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns a 6-link swimmer."""
- return _make_swimmer(6, time_limit, random=random)
+ return _make_swimmer(6, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
@SUITE.add('benchmarking')
-def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns a 15-link swimmer."""
- return _make_swimmer(15, time_limit, random=random)
+ return _make_swimmer(15, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
def swimmer(n_links=3, time_limit=_DEFAULT_TIME_LIMIT,
- random=None):
+ random=None, environment_kwargs=None):
"""Returns a swimmer with n links."""
- return _make_swimmer(n_links, time_limit, random=random)
+ return _make_swimmer(n_links, time_limit, random=random,
+ environment_kwargs=environment_kwargs)
-def _make_swimmer(n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def _make_swimmer(n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None,
+ environment_kwargs=None):
"""Returns a swimmer control environment."""
model_string, assets = get_model_and_assets(n_joints)
physics = Physics.from_xml_string(model_string, assets=assets)
task = Swimmer(random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
def _make_model(n_bodies):
@@ -90,7 +90,7 @@ def _make_model(n_bodies):
sensor = etree.SubElement(mjcf, 'sensor')
parent = head_body
- for body_index in xrange(n_bodies - 1):
+ for body_index in range(n_bodies - 1):
site_name = 'site_{}'.format(body_index)
child = _make_body(body_index=body_index)
child.append(etree.Element('site', name=site_name))
@@ -155,7 +155,7 @@ def body_velocities(self):
def joints(self):
"""Returns all internal joint angles (excluding root joints)."""
- return self.data.qpos[3:]
+ return self.data.qpos[3:].copy()
class Swimmer(base.Task):
@@ -169,7 +169,7 @@ def __init__(self, random=None):
integer seed for creating a new `RandomState`, or None to select a seed
automatically (default).
"""
- super(Swimmer, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -191,6 +191,8 @@ def initialize_episode(self, physics):
physics.named.model.light_pos['target_light', 'x'] = xpos
physics.named.model.light_pos['target_light', 'y'] = ypos
+ super().initialize_episode(physics)
+
def get_observation(self, physics):
"""Returns an observation of joint angles, body velocities and target."""
obs = collections.OrderedDict()
diff --git a/dm_control/suite/tests/domains_test.py b/dm_control/suite/tests/domains_test.py
deleted file mode 100644
index 933ec40b..00000000
--- a/dm_control/suite/tests/domains_test.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Tests for dm_control.suite domains."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
-from absl.testing import absltest
-from absl.testing import parameterized
-
-from dm_control import suite
-
-import numpy as np
-import six
-
-_NUM_EPISODES = 5
-_NUM_STEPS_PER_EPISODE = 10
-
-
-class DomainTest(parameterized.TestCase):
- """Tests run on all the tasks registered."""
-
- def test_constants(self):
- num_tasks = sum(len(tasks) for tasks in
- six.itervalues(suite.TASKS_BY_DOMAIN))
-
- self.assertEqual(len(suite.ALL_TASKS), num_tasks)
-
- def _validate_observation(self, observation_dict, observation_spec):
- obs = observation_dict.copy()
- for name, spec in six.iteritems(observation_spec):
- arr = obs.pop(name)
- self.assertEqual(arr.shape, spec.shape)
- self.assertEqual(arr.dtype, spec.dtype)
- self.assertTrue(
- np.all(np.isfinite(arr)),
- msg='{!r} has non-finite value(s): {!r}'.format(name, arr))
- self.assertEmpty(
- obs,
- msg='Observation contains arrays(s) that are not in the spec: {!r}'
- .format(obs))
-
- def _validate_reward_range(self, time_step):
- if time_step.first():
- self.assertIsNone(time_step.reward)
- else:
- self.assertIsInstance(time_step.reward, float)
- self.assertBetween(time_step.reward, 0, 1)
-
- def _validate_discount(self, time_step):
- if time_step.first():
- self.assertIsNone(time_step.discount)
- else:
- self.assertIsInstance(time_step.discount, float)
- self.assertBetween(time_step.discount, 0, 1)
-
- def _validate_control_range(self, lower_bounds, upper_bounds):
- for b in lower_bounds:
- self.assertEqual(b, -1.0)
- for b in upper_bounds:
- self.assertEqual(b, 1.0)
-
- @parameterized.parameters(*suite.ALL_TASKS)
- def test_components_have_names(self, domain, task):
- env = suite.load(domain, task)
- model = env.physics.model
-
- object_types_and_size_fields = {
- 'body': 'nbody',
- 'joint': 'njnt',
- 'geom': 'ngeom',
- 'site': 'nsite',
- 'camera': 'ncam',
- 'light': 'nlight',
- 'mesh': 'nmesh',
- 'hfield': 'nhfield',
- 'texture': 'ntex',
- 'material': 'nmat',
- 'equality': 'neq',
- 'tendon': 'ntendon',
- 'actuator': 'nu',
- 'sensor': 'nsensor',
- 'numeric': 'nnumeric',
- 'text': 'ntext',
- 'tuple': 'ntuple',
- }
-
- for object_type, size_field in six.iteritems(object_types_and_size_fields):
- for idx in range(getattr(model, size_field)):
- object_name = model.id2name(idx, object_type)
- self.assertNotEqual(object_name, '',
- msg='Model {!r} contains unnamed {!r} with ID {}.'
- .format(model.name, object_type, idx))
-
- @parameterized.parameters(*suite.ALL_TASKS)
- def test_task_runs(self, domain, task):
- """Tests task runs correctly and observation is coherent with spec."""
- is_benchmark = (domain, task) in suite.BENCHMARKING
- env = suite.load(domain, task)
-
- observation_spec = env.observation_spec()
- action_spec = env.action_spec()
- model = env.physics.model
-
- # Check cameras.
- self.assertGreaterEqual(model.ncam, 2, 'Model {!r} should have at least 2 '
- 'cameras, has {!r}.'.format(model.name, model.ncam))
-
- # Check action bounds.
- lower_bounds = action_spec.minimum
- upper_bounds = action_spec.maximum
-
- if is_benchmark:
- self._validate_control_range(lower_bounds, upper_bounds)
-
- lower_bounds = np.where(np.isinf(lower_bounds), -1.0, lower_bounds)
- upper_bounds = np.where(np.isinf(upper_bounds), 1.0, upper_bounds)
-
- # Run a partial episode, check observations, rewards, discount.
- for _ in range(_NUM_EPISODES):
- time_step = env.reset()
- for _ in range(_NUM_STEPS_PER_EPISODE):
- self._validate_observation(time_step.observation, observation_spec)
- if is_benchmark:
- self._validate_reward_range(time_step)
- self._validate_discount(time_step)
- action = np.random.uniform(lower_bounds, upper_bounds)
- time_step = env.step(action)
-
- @parameterized.parameters(*suite.ALL_TASKS)
- def test_visualize_reward(self, domain, task):
- env = suite.load(domain, task)
- env.task.visualise_reward = True
- env.reset()
- action = np.zeros(env.action_spec().shape)
- for _ in range(2):
- env.step(action)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/dm_control/suite/utils/parse_amc.py b/dm_control/suite/utils/parse_amc.py
index 24ea2e53..0614c37b 100644
--- a/dm_control/suite/utils/parse_amc.py
+++ b/dm_control/suite/utils/parse_amc.py
@@ -15,24 +15,12 @@
"""Parse and convert amc motion capture data."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
-from dm_control.mujoco.wrapper import mjbindings
-
+from dm_control.mujoco import math as mjmath
import numpy as np
-
from scipy import interpolate
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-mjlib = mjbindings.mjlib
-
MOCAP_DT = 1.0/120.0
CONVERSION_LENGTH = 0.056444
@@ -98,7 +86,7 @@ def convert(file_name, physics, timestep):
while time_vals_new[-1] > time_vals[-1]:
time_vals_new = time_vals_new[:-1]
- for i in xrange(qpos_values.shape[1]):
+ for i in range(qpos_values.shape[1]):
f = interpolate.splrep(time_vals, qpos_values[:, i])
qpos_values_resampled.append(interpolate.splev(time_vals_new, f))
@@ -109,7 +97,8 @@ def convert(file_name, physics, timestep):
p_tp1 = qpos_values_resampled[:, t + 1]
p_t = qpos_values_resampled[:, t]
qvel = [(p_tp1[:3]-p_t[:3])/ timestep,
- mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep),
+ mjmath.mj_quat2vel(
+ mjmath.mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep),
(p_tp1[7:]-p_t[7:])/ timestep]
qvel_list.append(np.concatenate(qvel))
@@ -134,7 +123,7 @@ def parse(file_name):
while True:
line = fid.readline().strip()
if not line or line == str(frame_ind):
- values.append(np.array(frame_vals, dtype=np.float))
+ values.append(np.array(frame_vals, dtype=float))
break
tokens = line.split()
frame_vals.extend(tokens[1:])
@@ -145,7 +134,7 @@ def parse(file_name):
while True:
line = fid.readline().strip()
if not line or line == str(frame_ind):
- values.append(np.array(frame_vals, dtype=np.float))
+ values.append(np.array(frame_vals, dtype=float))
break
tokens = line.split()
frame_vals.extend(tokens[1:])
@@ -156,7 +145,7 @@ def parse(file_name):
return values
-class Amcvals2qpos(object):
+class Amcvals2qpos:
"""Callable that converts .amc values for a frame and to MuJoCo qpos format.
"""
@@ -174,8 +163,8 @@ def __init__(self, index2joint, joint_order):
[[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH
self.qpos_root_quat_ind = [3, 4, 5, 6]
amc2qpos_transform = np.zeros((len(index2joint), len(joint_order)))
- for i in xrange(len(index2joint)):
- for j in xrange(len(joint_order)):
+ for i in range(len(index2joint)):
+ for j in range(len(joint_order)):
if index2joint[i] == joint_order[j]:
if 'rx' in index2joint[i]:
amc2qpos_transform[i][j] = 1
@@ -192,65 +181,10 @@ def __call__(self, amc_val):
# Root.
qpos[:3] = np.dot(self.root_xyz_ransform, amc_val[:3])
- qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5])
- qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat)
+ qpos_quat = mjmath.euler2quat(amc_val[3], amc_val[4], amc_val[5])
+ qpos_quat = mjmath.mj_quatprod(mjmath.euler2quat(90, 0, 0), qpos_quat)
for i, ind in enumerate(self.qpos_root_quat_ind):
qpos[ind] = qpos_quat[i]
return qpos
-
-
-def euler2quat(ax, ay, az):
- """Converts euler angles to a quaternion.
-
- Note: rotation order is zyx
-
- Args:
- ax: Roll angle (deg)
- ay: Pitch angle (deg).
- az: Yaw angle (deg).
-
- Returns:
- A numpy array representing the rotation as a quaternion.
- """
- r1 = az
- r2 = ay
- r3 = ax
-
- c1 = np.cos(np.deg2rad(r1 / 2))
- s1 = np.sin(np.deg2rad(r1 / 2))
- c2 = np.cos(np.deg2rad(r2 / 2))
- s2 = np.sin(np.deg2rad(r2 / 2))
- c3 = np.cos(np.deg2rad(r3 / 2))
- s3 = np.sin(np.deg2rad(r3 / 2))
-
- q0 = c1 * c2 * c3 + s1 * s2 * s3
- q1 = c1 * c2 * s3 - s1 * s2 * c3
- q2 = c1 * s2 * c3 + s1 * c2 * s3
- q3 = s1 * c2 * c3 - c1 * s2 * s3
-
- return np.array([q0, q1, q2, q3])
-
-
-def mj_quatprod(q, r):
- quaternion = np.zeros(4)
- mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q),
- np.ascontiguousarray(r))
- return quaternion
-
-
-def mj_quat2vel(q, dt):
- vel = np.zeros(3)
- mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt)
- return vel
-
-
-def mj_quatneg(q):
- quaternion = np.zeros(4)
- mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q))
- return quaternion
-
-
-def mj_quatdiff(source, target):
- return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target))
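Note on the hunk above: the deleted `euler2quat` helper (which this diff now reaches through `mjmath.euler2quat`) converts roll/pitch/yaw given in degrees into a quaternion using z-y-x rotation order. The following is a dependency-free sketch of the same arithmetic, mirroring the removed lines. The name `euler2quat_sketch` is hypothetical and is not part of dm_control.

```python
import math


def euler2quat_sketch(ax, ay, az):
  """Converts roll/pitch/yaw in degrees (zyx order) to a [w, x, y, z] list."""
  # Half-angles in radians; the deleted helper uses r1=az, r2=ay, r3=ax.
  h1 = math.radians(az) / 2
  h2 = math.radians(ay) / 2
  h3 = math.radians(ax) / 2

  c1, s1 = math.cos(h1), math.sin(h1)
  c2, s2 = math.cos(h2), math.sin(h2)
  c3, s3 = math.cos(h3), math.sin(h3)

  # Same four products as the removed q0..q3 lines.
  q0 = c1 * c2 * c3 + s1 * s2 * s3
  q1 = c1 * c2 * s3 - s1 * s2 * c3
  q2 = c1 * s2 * c3 + s1 * c2 * s3
  q3 = s1 * c2 * c3 - c1 * s2 * s3
  return [q0, q1, q2, q3]


identity = euler2quat_sketch(0, 0, 0)
roll_90 = euler2quat_sketch(90, 0, 0)
```

`roll_90` corresponds to the `euler2quat(90, 0, 0)` factor that `Amcvals2qpos.__call__` multiplies into the root orientation.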
diff --git a/dm_control/suite/utils/parse_amc_test.py b/dm_control/suite/utils/parse_amc_test.py
index 1d2f40e2..efc75946 100644
--- a/dm_control/suite/utils/parse_amc_test.py
+++ b/dm_control/suite/utils/parse_amc_test.py
@@ -15,19 +15,13 @@
"""Tests for parse_amc utility."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import os
-# Internal dependencies.
-
from absl.testing import absltest
from dm_control.suite import humanoid_CMU
from dm_control.suite.utils import parse_amc
-from dm_control.utils import resources
+from dm_control.utils import io as resources
_TEST_AMC_PATH = resources.GetResourceFilename(
os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
diff --git a/dm_control/suite/utils/randomizers.py b/dm_control/suite/utils/randomizers.py
index df557ae1..bc8f404e 100644
--- a/dm_control/suite/utils/randomizers.py
+++ b/dm_control/suite/utils/randomizers.py
@@ -15,16 +15,9 @@
"""Randomization functions."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
from dm_control.mujoco.wrapper import mjbindings
-
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
def random_limited_quaternion(random, limit):
@@ -65,7 +58,7 @@ def randomize_limited_and_rotational_joints(physics, random=None):
qpos = physics.named.data.qpos
- for joint_id in xrange(physics.model.njnt):
+ for joint_id in range(physics.model.njnt):
joint_name = physics.model.id2name(joint_id, 'joint')
joint_type = physics.model.jnt_type[joint_id]
is_limited = physics.model.jnt_limited[joint_id]
@@ -88,6 +81,8 @@ def randomize_limited_and_rotational_joints(physics, random=None):
qpos[joint_name] = quat
elif joint_type == free:
+ # This should be random.randn, but changing it now could significantly
+ # affect benchmark results.
quat = random.rand(4)
quat /= np.linalg.norm(quat)
qpos[joint_name][3:] = quat
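Note on the comment added in this hunk: the diff does not say why `random.randn` would be preferable, so the following is an assumption. Normalizing four independent standard-normal samples yields a quaternion distributed uniformly over the unit 3-sphere, whereas `rand` samples from [0, 1), so every component is non-negative and only part of the sphere can be reached. A stdlib-only sketch of the difference, with hypothetical function names:

```python
import math
import random


def normalized(vec):
  """Scales a vector to unit Euclidean length."""
  norm = math.sqrt(sum(v * v for v in vec))
  return [v / norm for v in vec]


def quat_from_uniform(rng):
  """Analogue of `random.rand(4)` followed by normalization."""
  return normalized([rng.random() for _ in range(4)])


def quat_from_normal(rng):
  """Analogue of `random.randn(4)` followed by normalization."""
  return normalized([rng.gauss(0.0, 1.0) for _ in range(4)])


rng = random.Random(0)
from_uniform = [quat_from_uniform(rng) for _ in range(1000)]
from_normal = [quat_from_normal(rng) for _ in range(1000)]

# Uniform samples never produce a negative component; normal samples do.
uniform_has_negative = any(c < 0 for q in from_uniform for c in q)
normal_has_negative = any(c < 0 for q in from_normal for c in q)
```

Both variants produce unit quaternions, which is consistent with the comment's point that the existing behaviour is valid but kept for benchmark reproducibility.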
diff --git a/dm_control/suite/utils/randomizers_test.py b/dm_control/suite/utils/randomizers_test.py
index f6dd6a1d..6b52c43d 100644
--- a/dm_control/suite/utils/randomizers_test.py
+++ b/dm_control/suite/utils/randomizers_test.py
@@ -13,34 +13,27 @@
# limitations under the License.
# ============================================================================
-"""Tests for randomizers.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
-
-from dm_control.mujoco import engine
-from dm_control.mujoco.wrapper.mjbindings import mjlib
+from dm_control import mujoco
+from dm_control.mujoco.wrapper import mjbindings
from dm_control.suite.utils import randomizers
-
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
+
+mjlib = mjbindings.mjlib
class RandomizeUnlimitedJointsTest(parameterized.TestCase):
def setUp(self):
+ super().setUp()
self.rand = np.random.RandomState(100)
def test_single_joint_of_each_type(self):
- physics = engine.Physics.from_xml_string("""
+ physics = mujoco.Physics.from_xml_string("""
-
+
@@ -50,13 +43,13 @@ def test_single_joint_of_each_type(self):
-
+
-
+
-
+
@@ -81,7 +74,7 @@ def test_single_joint_of_each_type(self):
self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3]))
def test_multiple_joints_of_same_type(self):
- physics = engine.Physics.from_xml_string("""
+ physics = mujoco.Physics.from_xml_string("""
@@ -107,7 +100,7 @@ def test_multiple_joints_of_same_type(self):
physics.named.data.qpos['hinge_3'])
def test_unlimited_hinge_randomization_range(self):
- physics = engine.Physics.from_xml_string("""
+ physics = mujoco.Physics.from_xml_string("""
@@ -116,12 +109,12 @@ def test_unlimited_hinge_randomization_range(self):
""")
- for _ in xrange(10):
+ for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi)
def test_limited_1d_joint_limits_are_respected(self):
- physics = engine.Physics.from_xml_string("""
+ physics = mujoco.Physics.from_xml_string("""
@@ -134,14 +127,14 @@ def test_limited_1d_joint_limits_are_respected(self):
""")
- for _ in xrange(10):
+ for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
self.assertBetween(physics.named.data.qpos['hinge'],
np.deg2rad(0), np.deg2rad(10))
self.assertBetween(physics.named.data.qpos['slide'], 30, 50)
def test_limited_ball_joint_are_respected(self):
- physics = engine.Physics.from_xml_string("""
+ physics = mujoco.Physics.from_xml_string("""
@@ -152,7 +145,7 @@ def test_limited_ball_joint_are_respected(self):
body_axis = np.array([1., 0., 0.])
joint_axis = np.zeros(3)
- for _ in xrange(10):
+ for _ in range(10):
randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
quat = physics.named.data.qpos['ball']
diff --git a/dm_control/suite/walker.py b/dm_control/suite/walker.py
index 9aedf9b7..a56eba91 100644
--- a/dm_control/suite/walker.py
+++ b/dm_control/suite/walker.py
@@ -15,14 +15,8 @@
"""Planar Walker Domain."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-# Internal dependencies.
-
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
@@ -52,30 +46,36 @@ def get_model_and_assets():
@SUITE.add('benchmarking')
-def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Stand task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = PlanarWalker(move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Walk task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = PlanarWalker(move_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
@SUITE.add('benchmarking')
-def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
"""Returns the Run task."""
physics = Physics.from_xml_string(*get_model_and_assets())
task = PlanarWalker(move_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
return control.Environment(
- physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
class Physics(mujoco.Physics):
@@ -91,7 +91,7 @@ def torso_height(self):
def horizontal_velocity(self):
"""Returns the horizontal velocity of the center-of-mass."""
- return self.named.data.subtree_linvel['torso', 'x']
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
def orientations(self):
"""Returns planar orientations of all bodies."""
@@ -113,7 +113,7 @@ def __init__(self, move_speed, random=None):
automatically (default).
"""
self._move_speed = move_speed
- super(PlanarWalker, self).__init__(random=random)
+ super().__init__(random=random)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode.
@@ -126,6 +126,7 @@ def initialize_episode(self, physics):
"""
randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ super().initialize_episode(physics)
def get_observation(self, physics):
"""Returns an observation of body orientations, height and velocites."""
diff --git a/dm_control/suite/walker.xml b/dm_control/suite/walker.xml
index b9072c23..36e85d79 100644
--- a/dm_control/suite/walker.xml
+++ b/dm_control/suite/walker.xml
@@ -55,6 +55,10 @@
+
+
+
+
diff --git a/dm_control/suite/wrappers/action_noise.py b/dm_control/suite/wrappers/action_noise.py
new file mode 100644
index 00000000..d5a114d9
--- /dev/null
+++ b/dm_control/suite/wrappers/action_noise.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Wraps control suite environments, adding Gaussian noise to actions."""
+
+import dm_env
+import numpy as np
+
+
+_BOUNDS_MUST_BE_FINITE = (
+ 'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
+
+
+class Wrapper(dm_env.Environment):
+ """Wraps a control environment and adds Gaussian noise to actions."""
+
+ def __init__(self, env, scale=0.01):
+ """Initializes a new action noise Wrapper.
+
+ Args:
+ env: The control suite environment to wrap.
+ scale: The standard deviation of the noise, expressed as a fraction
+ of the max-min range for each action dimension.
+
+ Raises:
+ ValueError: If any of the action dimensions of the wrapped environment are
+ unbounded.
+ """
+ action_spec = env.action_spec()
+ if not (np.all(np.isfinite(action_spec.minimum)) and
+ np.all(np.isfinite(action_spec.maximum))):
+ raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
+ self._minimum = action_spec.minimum
+ self._maximum = action_spec.maximum
+ self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
+ self._env = env
+
+ def step(self, action):
+ noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
+ # Clip the noisy actions in place so that they fall within the bounds
+ # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
+ # bounds control inputs, but we also clip here in case the actions do not
+ # correspond directly to MuJoCo actuators, or if there are other wrapper
+ # layers that expect the actions to be within bounds.
+ np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
+ return self._env.step(noisy_action)
+
+ def reset(self):
+ return self._env.reset()
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
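Note on `action_noise.Wrapper.step` above: the logic reduces to "add zero-mean Gaussian noise whose standard deviation is `scale` times the per-dimension action range, then clip to the bounds". The sketch below reproduces that arithmetic with the standard library only, without dm_env or NumPy. `noisy_action` is a hypothetical helper, not part of the wrapper's API.

```python
import random


def noisy_action(action, minimum, maximum, scale, rng):
  """Adds range-scaled Gaussian noise to each action dimension, then clips."""
  result = []
  for value, lower, upper in zip(action, minimum, maximum):
    # Standard deviation is a fraction of the (max - min) range.
    std = scale * (upper - lower)
    sample = value + rng.gauss(0.0, std)
    # Clip so the perturbed action stays within the action spec bounds.
    result.append(min(max(sample, lower), upper))
  return result


rng = random.Random(0)
unchanged = noisy_action([0.5, 1.0], [-1.0, 0.0], [1.0, 2.0], 0.0, rng)
perturbed = noisy_action([1.0, 2.0], [-1.0, 0.0], [1.0, 2.0], 5.0, rng)
```

With `scale=0.0` the action passes through unchanged, which matches the second parameter set in `test_step` of the accompanying test file.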
diff --git a/dm_control/suite/wrappers/action_noise_test.py b/dm_control/suite/wrappers/action_noise_test.py
new file mode 100644
index 00000000..d0be4ea2
--- /dev/null
+++ b/dm_control/suite/wrappers/action_noise_test.py
@@ -0,0 +1,131 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the action noise wrapper."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.rl import control
+from dm_control.suite.wrappers import action_noise
+from dm_env import specs
+import mock
+import numpy as np
+
+
+class ActionNoiseTest(parameterized.TestCase):
+
+ def make_action_spec(self, lower=(-1.,), upper=(1.,)):
+ lower, upper = np.broadcast_arrays(lower, upper)
+ return specs.BoundedArray(
+ shape=lower.shape, dtype=float, minimum=lower, maximum=upper)
+
+ def make_mock_env(self, action_spec=None):
+ action_spec = action_spec or self.make_action_spec()
+ env = mock.Mock(spec=control.Environment)
+ env.action_spec.return_value = action_spec
+ return env
+
+ def assertStepCalledOnceWithCorrectAction(self, env, expected_action):
+ # NB: `assert_called_once_with()` doesn't support numpy arrays.
+ env.step.assert_called_once()
+ actual_action = env.step.call_args_list[0][0][0]
+ np.testing.assert_array_equal(expected_action, actual_action)
+
+ @parameterized.parameters([
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.05),
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.),
+ dict(lower=np.r_[-1., 0.], upper=np.r_[-1., 0.], scale=0.05),
+ ])
+ def test_step(self, lower, upper, scale):
+ seed = 0
+ std = scale * (upper - lower)
+ expected_noise = np.random.RandomState(seed).normal(scale=std)
+ action = np.random.RandomState(seed).uniform(lower, upper)
+ expected_noisy_action = np.clip(action + expected_noise, lower, upper)
+ task = mock.Mock(spec=control.Task)
+ task.random = np.random.RandomState(seed)
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ env.task = task
+ wrapped_env = action_noise.Wrapper(env, scale=scale)
+ time_step = wrapped_env.step(action)
+ self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
+ self.assertIs(time_step, env.step(expected_noisy_action))
+
+ @parameterized.named_parameters([
+ dict(testcase_name='within_bounds', action=np.r_[-1.], noise=np.r_[0.1]),
+ dict(testcase_name='below_lower', action=np.r_[-1.], noise=np.r_[-0.1]),
+ dict(testcase_name='above_upper', action=np.r_[1.], noise=np.r_[0.1]),
+ ])
+ def test_action_clipping(self, action, noise):
+ lower = -1.
+ upper = 1.
+ expected_noisy_action = np.clip(action + noise, lower, upper)
+ task = mock.Mock(spec=control.Task)
+ task.random = mock.Mock(spec=np.random.RandomState)
+ task.random.normal.return_value = noise
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ env.task = task
+ wrapped_env = action_noise.Wrapper(env)
+ time_step = wrapped_env.step(action)
+ self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action)
+ self.assertIs(time_step, env.step(expected_noisy_action))
+
+ @parameterized.parameters([
+ dict(lower=np.r_[-1., 0.], upper=np.r_[1., np.inf]),
+ dict(lower=np.r_[np.nan, 0.], upper=np.r_[1., 2.]),
+ ])
+ def test_error_if_action_bounds_non_finite(self, lower, upper):
+ action_spec = self.make_action_spec(lower=lower, upper=upper)
+ env = self.make_mock_env(action_spec=action_spec)
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ action_noise._BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)):
+ _ = action_noise.Wrapper(env)
+
+ def test_reset(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ time_step = wrapped_env.reset()
+ env.reset.assert_called_once_with()
+ self.assertIs(time_step, env.reset())
+
+ def test_observation_spec(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ observation_spec = wrapped_env.observation_spec()
+ env.observation_spec.assert_called_once_with()
+ self.assertIs(observation_spec, env.observation_spec())
+
+ def test_action_spec(self):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ # `env.action_spec()` is called in `Wrapper.__init__()`
+ env.action_spec.reset_mock()
+ action_spec = wrapped_env.action_spec()
+ env.action_spec.assert_called_once_with()
+ self.assertIs(action_spec, env.action_spec())
+
+ @parameterized.parameters(['task', 'physics', 'control_timestep'])
+ def test_getattr(self, attribute_name):
+ env = self.make_mock_env()
+ wrapped_env = action_noise.Wrapper(env)
+ attr = getattr(wrapped_env, attribute_name)
+ self.assertIs(attr, getattr(env, attribute_name))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/wrappers/action_scale.py b/dm_control/suite/wrappers/action_scale.py
new file mode 100644
index 00000000..2dfef133
--- /dev/null
+++ b/dm_control/suite/wrappers/action_scale.py
@@ -0,0 +1,103 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Wrapper that scales actions to a specific range."""
+
+import dm_env
+from dm_env import specs
+import numpy as np
+
+_ACTION_SPEC_MUST_BE_BOUNDED_ARRAY = (
+ "`env.action_spec()` must return a single `BoundedArray`, got: {}.")
+_MUST_BE_FINITE = "All values in `{name}` must be finite, got: {bounds}."
+_MUST_BROADCAST = (
+ "`{name}` must be broadcastable to shape {shape}, got: {bounds}.")
+
+
+class Wrapper(dm_env.Environment):
+ """Wraps a control environment to rescale actions to a specific range."""
+ __slots__ = ("_action_spec", "_env", "_transform")
+
+ def __init__(self, env, minimum, maximum):
+ """Initializes a new action scale Wrapper.
+
+ Args:
+ env: Instance of `dm_env.Environment` to wrap. Its `action_spec` must
+ consist of a single `BoundedArray` with all-finite bounds.
+ minimum: Scalar or array-like specifying element-wise lower bounds
+ (inclusive) for the `action_spec` of the wrapped environment. Must be
+ finite and broadcastable to the shape of the `action_spec`.
+ maximum: Scalar or array-like specifying element-wise upper bounds
+ (inclusive) for the `action_spec` of the wrapped environment. Must be
+ finite and broadcastable to the shape of the `action_spec`.
+
+ Raises:
+ ValueError: If `env.action_spec()` is not a single `BoundedArray`.
+ ValueError: If `env.action_spec()` has non-finite bounds.
+ ValueError: If `minimum` or `maximum` contain non-finite values.
+ ValueError: If `minimum` or `maximum` are not broadcastable to
+ `env.action_spec().shape`.
+ """
+ action_spec = env.action_spec()
+ if not isinstance(action_spec, specs.BoundedArray):
+ raise ValueError(_ACTION_SPEC_MUST_BE_BOUNDED_ARRAY.format(action_spec))
+
+ minimum = np.array(minimum)
+ maximum = np.array(maximum)
+ shape = action_spec.shape
+ orig_minimum = action_spec.minimum
+ orig_maximum = action_spec.maximum
+ orig_dtype = action_spec.dtype
+
+ def validate(bounds, name):
+ if not np.all(np.isfinite(bounds)):
+ raise ValueError(_MUST_BE_FINITE.format(name=name, bounds=bounds))
+ try:
+ np.broadcast_to(bounds, shape)
+ except ValueError:
+ raise ValueError(_MUST_BROADCAST.format(
+ name=name, bounds=bounds, shape=shape))
+
+ validate(minimum, "minimum")
+ validate(maximum, "maximum")
+ validate(orig_minimum, "env.action_spec().minimum")
+ validate(orig_maximum, "env.action_spec().maximum")
+
+ scale = (orig_maximum - orig_minimum) / (maximum - minimum)
+
+ def transform(action):
+ new_action = orig_minimum + scale * (action - minimum)
+ return new_action.astype(orig_dtype, copy=False)
+
+ dtype = np.result_type(minimum, maximum, orig_dtype)
+ self._action_spec = action_spec.replace(
+ minimum=minimum, maximum=maximum, dtype=dtype)
+ self._env = env
+ self._transform = transform
+
+ def step(self, action):
+ return self._env.step(self._transform(action))
+
+ def reset(self):
+ return self._env.reset()
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._action_spec
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
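Note on `action_scale.Wrapper` above: the `transform` closure is an element-wise affine map from the new `[minimum, maximum]` range onto the wrapped environment's original bounds. The sketch below shows the same formula on plain Python lists. `make_transform` and its argument names are hypothetical, and broadcasting and dtype handling are omitted.

```python
def make_transform(orig_minimum, orig_maximum, minimum, maximum):
  """Returns a function mapping [minimum, maximum] onto the original bounds."""
  # Per-dimension ratio of the original range to the new range.
  scales = [
      (orig_hi - orig_lo) / (new_hi - new_lo)
      for orig_lo, orig_hi, new_lo, new_hi in zip(
          orig_minimum, orig_maximum, minimum, maximum)
  ]

  def transform(action):
    return [
        orig_lo + scale * (value - new_lo)
        for orig_lo, scale, value, new_lo in zip(
            orig_minimum, scales, action, minimum)
    ]

  return transform


# Agent acts in [-2, 2]; the wrapped environment expects [-1, 1].
transform = make_transform([-1.0, -1.0], [1.0, 1.0], [-2.0, -2.0], [2.0, 2.0])
lowest = transform([-2.0, -2.0])
middle = transform([0.0, 0.0])
highest = transform([2.0, 2.0])
```

In this sketch a degenerate range (`minimum == maximum` in some dimension) raises `ZeroDivisionError`. The wrapper in the diff validates that bounds are finite and broadcastable but shows no explicit check for this case.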
diff --git a/dm_control/suite/wrappers/action_scale_test.py b/dm_control/suite/wrappers/action_scale_test.py
new file mode 100644
index 00000000..144cf43a
--- /dev/null
+++ b/dm_control/suite/wrappers/action_scale_test.py
@@ -0,0 +1,157 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the action scale wrapper."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.rl import control
+from dm_control.suite.wrappers import action_scale
+from dm_env import specs
+import mock
+import numpy as np
+
+
+def make_action_spec(lower=(-1.,), upper=(1.,)):
+ lower, upper = np.broadcast_arrays(lower, upper)
+ return specs.BoundedArray(
+ shape=lower.shape, dtype=float, minimum=lower, maximum=upper)
+
+
+def make_mock_env(action_spec):
+ env = mock.Mock(spec=control.Environment)
+ env.action_spec.return_value = action_spec
+ return env
+
+
+class ActionScaleTest(parameterized.TestCase):
+
+ def assertStepCalledOnceWithCorrectAction(self, env, expected_action):
+ # NB: `assert_called_once_with()` doesn't support numpy arrays.
+ env.step.assert_called_once()
+ actual_action = env.step.call_args_list[0][0][0]
+ np.testing.assert_array_equal(expected_action, actual_action)
+
+ @parameterized.parameters(
+ {
+ 'minimum': np.r_[-1., -1.],
+ 'maximum': np.r_[1., 1.],
+ 'scaled_minimum': np.r_[-2., -2.],
+ 'scaled_maximum': np.r_[2., 2.],
+ },
+ {
+ 'minimum': np.r_[-2., -2.],
+ 'maximum': np.r_[2., 2.],
+ 'scaled_minimum': np.r_[-1., -1.],
+ 'scaled_maximum': np.r_[1., 1.],
+ },
+ {
+ 'minimum': np.r_[-1., -1.],
+ 'maximum': np.r_[1., 1.],
+ 'scaled_minimum': np.r_[-2., -2.],
+ 'scaled_maximum': np.r_[1., 1.],
+ },
+ {
+ 'minimum': np.r_[-1., -1.],
+ 'maximum': np.r_[1., 1.],
+ 'scaled_minimum': np.r_[-1., -1.],
+ 'scaled_maximum': np.r_[2., 2.],
+ },
+ )
+ def test_step(self, minimum, maximum, scaled_minimum, scaled_maximum):
+ action_spec = make_action_spec(lower=minimum, upper=maximum)
+ env = make_mock_env(action_spec=action_spec)
+ wrapped_env = action_scale.Wrapper(
+ env, minimum=scaled_minimum, maximum=scaled_maximum)
+
+ time_step = wrapped_env.step(scaled_minimum)
+ self.assertStepCalledOnceWithCorrectAction(env, minimum)
+ self.assertIs(time_step, env.step(minimum))
+
+ env.reset_mock()
+
+ time_step = wrapped_env.step(scaled_maximum)
+ self.assertStepCalledOnceWithCorrectAction(env, maximum)
+ self.assertIs(time_step, env.step(maximum))
+
+ @parameterized.parameters(
+ {
+ 'minimum': np.r_[-1., -1.],
+ 'maximum': np.r_[1., 1.],
+ },
+ {
+ 'minimum': np.r_[0, 1],
+ 'maximum': np.r_[2, 3],
+ },
+ )
+ def test_correct_action_spec(self, minimum, maximum):
+ original_action_spec = make_action_spec(
+ lower=np.r_[-2., -2.], upper=np.r_[2., 2.])
+ env = make_mock_env(action_spec=original_action_spec)
+ wrapped_env = action_scale.Wrapper(env, minimum=minimum, maximum=maximum)
+ new_action_spec = wrapped_env.action_spec()
+ np.testing.assert_array_equal(new_action_spec.minimum, minimum)
+ np.testing.assert_array_equal(new_action_spec.maximum, maximum)
+
+ @parameterized.parameters('reset', 'observation_spec', 'control_timestep')
+ def test_method_delegated_to_underlying_env(self, method_name):
+ env = make_mock_env(action_spec=make_action_spec())
+ wrapped_env = action_scale.Wrapper(env, minimum=0, maximum=1)
+ env_method = getattr(env, method_name)
+ wrapper_method = getattr(wrapped_env, method_name)
+ out = wrapper_method()
+ env_method.assert_called_once_with()
+ self.assertIs(out, env_method())
+
+ def test_invalid_action_spec_type(self):
+ action_spec = [make_action_spec()] * 2
+ env = make_mock_env(action_spec=action_spec)
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ action_scale._ACTION_SPEC_MUST_BE_BOUNDED_ARRAY.format(action_spec)):
+ action_scale.Wrapper(env, minimum=0, maximum=1)
+
+ @parameterized.parameters(
+ {'name': 'minimum', 'bounds': np.r_[np.nan]},
+ {'name': 'minimum', 'bounds': np.r_[-np.inf]},
+ {'name': 'maximum', 'bounds': np.r_[np.inf]},
+ )
+ def test_non_finite_bounds(self, name, bounds):
+ kwargs = {'minimum': np.r_[-1.], 'maximum': np.r_[1.]}
+ kwargs[name] = bounds
+ env = make_mock_env(action_spec=make_action_spec())
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ action_scale._MUST_BE_FINITE.format(name=name, bounds=bounds)):
+ action_scale.Wrapper(env, **kwargs)
+
+ @parameterized.parameters(
+ {'name': 'minimum', 'bounds': np.r_[1., 2., 3.]},
+ {'name': 'minimum', 'bounds': np.r_[[1.], [2.], [3.]]},
+ )
+ def test_invalid_bounds_shape(self, name, bounds):
+ shape = (2,)
+ kwargs = {'minimum': np.zeros(shape), 'maximum': np.ones(shape)}
+ kwargs[name] = bounds
+ action_spec = make_action_spec(lower=[-1, -1], upper=[2, 3])
+ env = make_mock_env(action_spec=action_spec)
+ with self.assertRaisesWithLiteralMatch(
+ ValueError,
+ action_scale._MUST_BROADCAST.format(
+ name=name, bounds=bounds, shape=shape)):
+ action_scale.Wrapper(env, **kwargs)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/wrappers/mujoco_profiling.py b/dm_control/suite/wrappers/mujoco_profiling.py
new file mode 100644
index 00000000..4d6afe43
--- /dev/null
+++ b/dm_control/suite/wrappers/mujoco_profiling.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Wrapper that adds MuJoCo profiling observations to a control environment."""
+
+import collections
+import dm_env
+from dm_env import specs
+import numpy as np
+
+STATE_KEY = 'state'
+
+
+class Wrapper(dm_env.Environment):
+ """Wraps a control environment and adds an observation with profile data.
+
+ The profile data describes the time MuJoCo spent in a "step", and the
+ observation consists of two values: the cumulative time spent on steps
+ (in seconds), and the number of times the profiling timer was called.
+ """
+
+ def __init__(self, env, observation_key='step_timing'):
+ """Initializes a new mujoco_profiling Wrapper.
+
+ Args:
+ env: The environment to wrap.
+ observation_key: Optional custom string specifying the profile
+ observation's key in the `OrderedDict` of observations. Defaults to
+ 'step_timing'.
+
+ Raises:
+ ValueError: If `env`'s observation spec is not compatible with the
+ wrapper. Supported formats are a single array, or a dict of arrays.
+ ValueError: If `env`'s observation already contains the specified
+ `observation_key`.
+ """
+ wrapped_observation_spec = env.observation_spec()
+
+ if isinstance(wrapped_observation_spec, specs.Array):
+ self._observation_is_dict = False
+ invalid_keys = set([STATE_KEY])
+ elif isinstance(wrapped_observation_spec, collections.abc.MutableMapping):
+ self._observation_is_dict = True
+ invalid_keys = set(wrapped_observation_spec.keys())
+ else:
+ raise ValueError('Unsupported observation spec structure.')
+
+ if observation_key in invalid_keys:
+ raise ValueError(
+ 'Duplicate or reserved observation key {!r}.'.format(observation_key))
+
+ if self._observation_is_dict:
+ self._observation_spec = wrapped_observation_spec.copy()
+ else:
+ self._observation_spec = collections.OrderedDict()
+ self._observation_spec[STATE_KEY] = wrapped_observation_spec
+
+ env.physics.enable_profiling()
+
+ # Extend observation spec.
+ self._observation_spec[observation_key] = specs.Array(
+ shape=(2,), dtype=np.double, name=observation_key)
+
+ self._env = env
+ self._observation_key = observation_key
+
+ def reset(self):
+ return self._add_profile_observation(self._env.reset())
+
+ def step(self, action):
+ return self._add_profile_observation(self._env.step(action))
+
+ def observation_spec(self):
+ return self._observation_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def _add_profile_observation(self, time_step):
+ if self._observation_is_dict:
+ observation = type(time_step.observation)(time_step.observation)
+ else:
+ observation = collections.OrderedDict()
+ observation[STATE_KEY] = time_step.observation
+
+ # timer[0] is the step timer. MuJoCo exposes many other timers (see
+ # mujoco/include/mjdata.h), but this wrapper only reports the step
+ # timer.
+ timing = self._env.physics.data.timer[0]
+
+ observation[self._observation_key] = np.array(
+ [timing.duration, timing.number], dtype=np.double)
+ return time_step._replace(observation=observation)
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
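
Reviewer note: the observation handling in `_add_profile_observation` can be sketched without `dm_env` or MuJoCo. The snippet below is a minimal illustration only. `TimeStep` and `add_profile_observation` are hypothetical stand-ins, and the `duration` and `number` arguments replace the values the wrapper reads from `physics.data.timer[0]`.

```python
import collections
import collections.abc

STATE_KEY = 'state'

# Hypothetical stand-in for dm_env.TimeStep (illustration only).
TimeStep = collections.namedtuple('TimeStep', ['reward', 'observation'])


def add_profile_observation(time_step, duration, number, key='step_timing'):
  """Returns a copy of `time_step` with a (duration, number) entry added."""
  if isinstance(time_step.observation, collections.abc.Mapping):
    # Dict observations are shallow-copied so the original is not mutated.
    observation = type(time_step.observation)(time_step.observation)
  else:
    # A single array observation is moved under the reserved 'state' key.
    observation = collections.OrderedDict()
    observation[STATE_KEY] = time_step.observation
  if key in observation:
    raise ValueError('Duplicate or reserved observation key {!r}.'.format(key))
  observation[key] = (float(duration), float(number))
  return time_step._replace(observation=observation)


dict_step = TimeStep(0.0, collections.OrderedDict(position=[0.0, 1.0]))
array_step = TimeStep(0.0, [0.0, 1.0])

wrapped_dict = add_profile_observation(dict_step, duration=0.002, number=5)
wrapped_array = add_profile_observation(array_step, duration=0.002, number=5)

print(list(wrapped_dict.observation))   # ['position', 'step_timing']
print(list(wrapped_array.observation))  # ['state', 'step_timing']
```

In both cases the profiling entry is appended last, which matches the key order the accompanying test checks for the dict case.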
diff --git a/dm_control/suite/wrappers/mujoco_profiling_test.py b/dm_control/suite/wrappers/mujoco_profiling_test.py
new file mode 100644
index 00000000..43186731
--- /dev/null
+++ b/dm_control/suite/wrappers/mujoco_profiling_test.py
@@ -0,0 +1,57 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the mujoco_profiling wrapper."""
+
+import collections
+
+from absl.testing import absltest
+from dm_control.suite import cartpole
+from dm_control.suite.wrappers import mujoco_profiling
+import numpy as np
+
+
+class MujocoProfilingTest(absltest.TestCase):
+
+ def test_dict_observation(self):
+ obs_key = 'mjprofile'
+
+ env = cartpole.swingup()
+
+ # Make sure we are testing the right environment for the test.
+ observation_spec = env.observation_spec()
+ self.assertIsInstance(observation_spec, collections.OrderedDict)
+
+ # The wrapper should only add one observation.
+ wrapped = mujoco_profiling.Wrapper(env, observation_key=obs_key)
+
+ wrapped_observation_spec = wrapped.observation_spec()
+ self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
+
+ expected_length = len(observation_spec) + 1
+ self.assertLen(wrapped_observation_spec, expected_length)
+ expected_keys = list(observation_spec.keys()) + [obs_key]
+ self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
+
+ # Check that the added spec item is consistent with the added observation.
+ time_step = wrapped.reset()
+ profile_observation = time_step.observation[obs_key]
+ wrapped_observation_spec[obs_key].validate(profile_observation)
+
+ self.assertEqual(profile_observation.shape, (2,))
+ self.assertEqual(profile_observation.dtype, np.double)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/wrappers/pixels.py b/dm_control/suite/wrappers/pixels.py
index da66865e..547fa767 100644
--- a/dm_control/suite/wrappers/pixels.py
+++ b/dm_control/suite/wrappers/pixels.py
@@ -15,21 +15,14 @@
"""Wrapper that adds pixel observations to a control environment."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-
-# Internal dependencies.
-
-from dm_control.rl import environment
-from dm_control.rl import specs
+import dm_env
+from dm_env import specs
STATE_KEY = 'state'
-class Wrapper(environment.Base):
+class Wrapper(dm_env.Environment):
"""Wraps a control environment and adds a rendered pixel observation."""
def __init__(self, env, pixels_only=True, render_kwargs=None,
@@ -59,10 +52,10 @@ def __init__(self, env, pixels_only=True, render_kwargs=None,
wrapped_observation_spec = env.observation_spec()
- if isinstance(wrapped_observation_spec, specs.ArraySpec):
+ if isinstance(wrapped_observation_spec, specs.Array):
self._observation_is_dict = False
invalid_keys = set([STATE_KEY])
- elif isinstance(wrapped_observation_spec, collections.MutableMapping):
+ elif isinstance(wrapped_observation_spec, collections.abc.MutableMapping):
self._observation_is_dict = True
invalid_keys = set(wrapped_observation_spec.keys())
else:
@@ -82,7 +75,7 @@ def __init__(self, env, pixels_only=True, render_kwargs=None,
# Extend observation spec.
pixels = env.physics.render(**render_kwargs)
- pixels_spec = specs.ArraySpec(
+ pixels_spec = specs.Array(
shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
self._observation_spec[observation_key] = pixels_spec
diff --git a/dm_control/suite/wrappers/pixels_test.py b/dm_control/suite/wrappers/pixels_test.py
index 39ab7bda..ccb2fd5d 100644
--- a/dm_control/suite/wrappers/pixels_test.py
+++ b/dm_control/suite/wrappers/pixels_test.py
@@ -15,28 +15,18 @@
"""Tests for the pixel wrapper."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-import unittest
-
-# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
-
-from dm_control import render
-from dm_control.rl import environment
-from dm_control.rl import specs
from dm_control.suite import cartpole
from dm_control.suite.wrappers import pixels
-
+import dm_env
+from dm_env import specs
import numpy as np
-class FakePhysics(object):
+class FakePhysics:
def render(self, *args, **kwargs):
del args
@@ -44,26 +34,25 @@ def render(self, *args, **kwargs):
return np.zeros((4, 5, 3), dtype=np.uint8)
-class FakeArrayObservationEnvironment(environment.Base):
+class FakeArrayObservationEnvironment(dm_env.Environment):
def __init__(self):
self.physics = FakePhysics()
def reset(self):
- return environment.restart(np.zeros((2,)))
+ return dm_env.restart(np.zeros((2,)))
def step(self, action):
del action
- return environment.transition(0.0, np.zeros((2,)))
+ return dm_env.transition(0.0, np.zeros((2,)))
def action_spec(self):
pass
def observation_spec(self):
- return specs.ArraySpec(shape=(2,), dtype=np.float)
+ return specs.Array(shape=(2,), dtype=float)
-@unittest.skipIf(render.DISABLED, reason=render.DISABLED_MESSAGE)
class PixelsTest(parameterized.TestCase):
@parameterized.parameters(True, False)
@@ -89,10 +78,11 @@ def test_dict_observation(self, pixels_only):
self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
if pixels_only:
- self.assertEqual(1, len(wrapped_observation_spec))
+ self.assertLen(wrapped_observation_spec, 1)
self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
else:
- self.assertEqual(len(observation_spec) + 1, len(wrapped_observation_spec))
+ expected_length = len(observation_spec) + 1
+ self.assertLen(wrapped_observation_spec, expected_length)
expected_keys = list(observation_spec.keys()) + [pixel_key]
self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
@@ -110,7 +100,7 @@ def test_single_array_observation(self, pixels_only):
env = FakeArrayObservationEnvironment()
observation_spec = env.observation_spec()
- self.assertIsInstance(observation_spec, specs.ArraySpec)
+ self.assertIsInstance(observation_spec, specs.Array)
wrapped = pixels.Wrapper(env, observation_key=pixel_key,
pixels_only=pixels_only)
@@ -118,10 +108,10 @@ def test_single_array_observation(self, pixels_only):
self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
if pixels_only:
- self.assertEqual(1, len(wrapped_observation_spec))
+ self.assertLen(wrapped_observation_spec, 1)
self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
else:
- self.assertEqual(2, len(wrapped_observation_spec))
+ self.assertLen(wrapped_observation_spec, 2)
self.assertEqual([pixels.STATE_KEY, pixel_key],
list(wrapped_observation_spec.keys()))
diff --git a/dm_control/third_party/README.md b/dm_control/third_party/README.md
new file mode 100644
index 00000000..f9f3b76d
--- /dev/null
+++ b/dm_control/third_party/README.md
@@ -0,0 +1,5 @@
+# Third-party assets
+
+This directory contains third-party assets which are licensed separately from
+the rest of `dm_control`. Please see the LICENSE files within the individual
+subpackages for further details.
diff --git a/dm_control/third_party/ant/LICENSE b/dm_control/third_party/ant/LICENSE
new file mode 100644
index 00000000..891043f5
--- /dev/null
+++ b/dm_control/third_party/ant/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2020 Philipp Moritz, The dm_control Authors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/dm_control/third_party/ant/README.md b/dm_control/third_party/ant/README.md
new file mode 100644
index 00000000..a4b6aaac
--- /dev/null
+++ b/dm_control/third_party/ant/README.md
@@ -0,0 +1,7 @@
+# Ant MJCF model
+
+This model is intended for use as a walker in `dm_control.locomotion`. It is
+based on Philipp Moritz's original Ant model, which is available from
+https://github.com/pcmoritz/mujoco-control-ant. Substantial modifications
+have been made to the original model and therefore this model should be regarded
+as distinct for practical purposes.
diff --git a/dm_control/third_party/ant/ant.xml b/dm_control/third_party/ant/ant.xml
new file mode 100644
index 00000000..979134e2
--- /dev/null
+++ b/dm_control/third_party/ant/ant.xml
@@ -0,0 +1,159 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/third_party/kinova/LICENSE b/dm_control/third_party/kinova/LICENSE
new file mode 100644
index 00000000..3265369f
--- /dev/null
+++ b/dm_control/third_party/kinova/LICENSE
@@ -0,0 +1,28 @@
+Original work: Copyright (c) 2017, Kinova Robotics inc.
+Modified work: Copyright (c) 2018, The dm_control Authors.
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+ * Neither the name of the copyright holder nor the names of its contributors
+ may be used to endorse or promote products derived from this software
+ without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/dm_control/third_party/kinova/common.xml b/dm_control/third_party/kinova/common.xml
new file mode 100644
index 00000000..64ab70b4
--- /dev/null
+++ b/dm_control/third_party/kinova/common.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/third_party/kinova/jaco_arm.xml b/dm_control/third_party/kinova/jaco_arm.xml
new file mode 100644
index 00000000..609b087e
--- /dev/null
+++ b/dm_control/third_party/kinova/jaco_arm.xml
@@ -0,0 +1,76 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/third_party/kinova/jaco_hand.xml b/dm_control/third_party/kinova/jaco_hand.xml
new file mode 100644
index 00000000..9c6087aa
--- /dev/null
+++ b/dm_control/third_party/kinova/jaco_hand.xml
@@ -0,0 +1,63 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/third_party/kinova/meshes/arm.stl b/dm_control/third_party/kinova/meshes/arm.stl
new file mode 100644
index 00000000..fbb8668e
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/arm.stl differ
diff --git a/dm_control/third_party/kinova/meshes/base.stl b/dm_control/third_party/kinova/meshes/base.stl
new file mode 100644
index 00000000..3eb0b4ba
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/base.stl differ
diff --git a/dm_control/third_party/kinova/meshes/finger_distal.stl b/dm_control/third_party/kinova/meshes/finger_distal.stl
new file mode 100644
index 00000000..1c8ff85d
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/finger_distal.stl differ
diff --git a/dm_control/third_party/kinova/meshes/finger_proximal.stl b/dm_control/third_party/kinova/meshes/finger_proximal.stl
new file mode 100644
index 00000000..43ca39f0
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/finger_proximal.stl differ
diff --git a/dm_control/third_party/kinova/meshes/forearm.stl b/dm_control/third_party/kinova/meshes/forearm.stl
new file mode 100644
index 00000000..e66b8b89
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/forearm.stl differ
diff --git a/dm_control/third_party/kinova/meshes/hand_3finger.stl b/dm_control/third_party/kinova/meshes/hand_3finger.stl
new file mode 100644
index 00000000..9ed2d59c
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/hand_3finger.stl differ
diff --git a/dm_control/third_party/kinova/meshes/hand_3finger_insert.stl b/dm_control/third_party/kinova/meshes/hand_3finger_insert.stl
new file mode 100644
index 00000000..5ff7030b
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/hand_3finger_insert.stl differ
diff --git a/dm_control/third_party/kinova/meshes/hand_3finger_main.stl b/dm_control/third_party/kinova/meshes/hand_3finger_main.stl
new file mode 100644
index 00000000..c9fd13d4
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/hand_3finger_main.stl differ
diff --git a/dm_control/third_party/kinova/meshes/shoulder.stl b/dm_control/third_party/kinova/meshes/shoulder.stl
new file mode 100644
index 00000000..cf91ea49
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/shoulder.stl differ
diff --git a/dm_control/third_party/kinova/meshes/wrist.stl b/dm_control/third_party/kinova/meshes/wrist.stl
new file mode 100644
index 00000000..8b2f6132
Binary files /dev/null and b/dm_control/third_party/kinova/meshes/wrist.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/arm.stl b/dm_control/third_party/kinova/meshes_decimated/arm.stl
new file mode 100644
index 00000000..f1937a44
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/arm.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/base.stl b/dm_control/third_party/kinova/meshes_decimated/base.stl
new file mode 100644
index 00000000..623ce51d
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/base.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/finger_distal.stl b/dm_control/third_party/kinova/meshes_decimated/finger_distal.stl
new file mode 100644
index 00000000..2be94ce9
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/finger_distal.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/finger_proximal.stl b/dm_control/third_party/kinova/meshes_decimated/finger_proximal.stl
new file mode 100644
index 00000000..935d502e
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/finger_proximal.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/forearm.stl b/dm_control/third_party/kinova/meshes_decimated/forearm.stl
new file mode 100644
index 00000000..cd050234
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/forearm.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/hand_3finger.stl b/dm_control/third_party/kinova/meshes_decimated/hand_3finger.stl
new file mode 100644
index 00000000..7e6920f7
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/hand_3finger.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/hand_3finger_insert.stl b/dm_control/third_party/kinova/meshes_decimated/hand_3finger_insert.stl
new file mode 100644
index 00000000..79783686
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/hand_3finger_insert.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/hand_3finger_main.stl b/dm_control/third_party/kinova/meshes_decimated/hand_3finger_main.stl
new file mode 100644
index 00000000..a85cd56a
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/hand_3finger_main.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/shoulder.stl b/dm_control/third_party/kinova/meshes_decimated/shoulder.stl
new file mode 100644
index 00000000..072c0368
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/shoulder.stl differ
diff --git a/dm_control/third_party/kinova/meshes_decimated/wrist.stl b/dm_control/third_party/kinova/meshes_decimated/wrist.stl
new file mode 100644
index 00000000..65512a9c
Binary files /dev/null and b/dm_control/third_party/kinova/meshes_decimated/wrist.stl differ
diff --git a/dm_control/utils/README.md b/dm_control/utils/README.md
deleted file mode 100644
index d605e4c8..00000000
--- a/dm_control/utils/README.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Tolerance
-
-`tolerance()` is a soft indicator function evaluating whether a number is within
-bounds.
-
-See [package documentation](/third_party/py/dm_control/utils).
diff --git a/dm_control/utils/containers.py b/dm_control/utils/containers.py
index 362362cf..b08a1762 100644
--- a/dm_control/utils/containers.py
+++ b/dm_control/utils/containers.py
@@ -15,79 +15,14 @@
"""Container classes used in control domains."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
-
-class Tasks(collections.Mapping):
- """Maps task names to their corresponding factory functions.
-
- To store a function in a `Tasks` container, we can use its `.add` decorator:
-
- ```python
- tasks = Tasks()
-
- @tasks.add
- def example_task():
- ...
- return environment
-
- environment_factory = tasks['example_task']
- ```
-
- To add tasks that are procedurally generated, we can pass the optional `name`
- argument to the `.add` method:
-
- ```python
- for difficulty in ('easy', 'normal', 'hard'):
- func = my_task_generator(difficulty)
- tasks.add(func, name='my_task_{}'.format(difficulty))
- ```
-
- """
-
- def __init__(self):
- self._tasks = collections.OrderedDict()
-
- def add(self, factory_func, name=None):
- """Decorator that adds a factory function to the container.
-
- Args:
- factory_func: A function that returns a `ControlEnvironment` instance.
- name: Optional task name. If unspecified, `factory_func.name` is used.
-
- Returns:
- The same function.
-
- Raises:
- ValueError: if a function with the same name already exists within the
- container.
- """
- if name is None:
- name = factory_func.__name__
- if name in self:
- raise ValueError("Function named {!r} already exists in the container."
- "".format(name))
- self._tasks[name] = factory_func
- return factory_func
-
- def __getitem__(self, k):
- return self._tasks[k]
-
- def __iter__(self):
- return iter(self._tasks)
-
- def __len__(self):
- return len(self._tasks)
-
- def __repr__(self):
- return "{}({})".format(self.__class__.__name__, str(self._tasks))
+_NAME_ALREADY_EXISTS = (
+ "A function named {name!r} already exists in the container and "
+ "`allow_overriding_keys` is False.")
-class TaggedTasks(collections.Mapping):
+class TaggedTasks(collections.abc.Mapping):
"""Maps task names to their corresponding factory functions with tags.
To store a function in a `TaggedTasks` container, we can use its `.add`
@@ -108,9 +43,17 @@ def example_task():
```
"""
- def __init__(self):
+ def __init__(self, allow_overriding_keys=False):
+ """Initializes a new `TaggedTasks` container.
+
+ Args:
+ allow_overriding_keys: Boolean, whether `add` can override existing keys
+ within the container. If False (default), calling `add` multiple times
+ with the same function name will result in a `ValueError`.
+ """
self._tasks = collections.OrderedDict()
self._tags = collections.defaultdict(dict)
+ self.allow_overriding_keys = allow_overriding_keys
def add(self, *tags):
"""Decorator that adds a factory function to the container with tags.
@@ -123,25 +66,37 @@ def add(self, *tags):
Raises:
ValueError: if a function with the same name already exists within the
- container.
+ container and `allow_overriding_keys` is False.
"""
def wrap(factory_func):
name = factory_func.__name__
- if name in self:
- raise ValueError("Function named {!r} already exists in the container."
- "".format(name))
+ if name in self and not self.allow_overriding_keys:
+ raise ValueError(_NAME_ALREADY_EXISTS.format(name=name))
self._tasks[name] = factory_func
for tag in tags:
self._tags[tag][name] = factory_func
return factory_func
return wrap
- def tagged(self, tag):
- """Returns a (possibly empty) dict of all items that match the given tag."""
- if tag not in self._tags:
+ def tagged(self, *tags):
+ """Returns a (possibly empty) dict of functions matching all the given tags.
+
+ Args:
+ *tags: Strings specifying tags to query by.
+
+ Returns:
+ A dict of `{name: function}` containing all the functions that are tagged
+ by all of the strings in `tags`.
+ """
+ if not tags:
+ return {}
+ tags = set(tags)
+ if not tags.issubset(self._tags.keys()):
return {}
- else:
- return self._tags[tag]
+ names = self._tags[tags.pop()].keys()
+ while tags:
+ names &= self._tags[tags.pop()].keys()
+ return {name: self._tasks[name] for name in names}
def tags(self):
"""Returns a list of all the tags in this container."""
diff --git a/dm_control/utils/containers_test.py b/dm_control/utils/containers_test.py
index 151ada40..aad36265 100644
--- a/dm_control/utils/containers_test.py
+++ b/dm_control/utils/containers_test.py
@@ -13,61 +13,14 @@
# limitations under the License.
# ============================================================================
-"""Tests for control.utils.containers."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
+"""Tests for dm_control.utils.containers."""
from absl.testing import absltest
-
+from absl.testing import parameterized
from dm_control.utils import containers
-class TaskTest(absltest.TestCase):
-
- def test_factory_registered(self):
- tasks = containers.Tasks()
-
- @tasks.add
- def test_factory1(): # pylint: disable=unused-variable
- return 'executed 1'
-
- @tasks.add
- def test_factory2(): # pylint: disable=unused-variable
- return 'executed 2'
-
- with self.assertRaises(ValueError):
- @tasks.add
- def test_factory1(): # pylint: disable=function-redefined
- return
-
- self.assertEqual(2, len(tasks))
- self.assertEqual(set(['test_factory1', 'test_factory2']),
- set(tasks.keys()))
- self.assertEqual('executed 1', tasks['test_factory1']())
- self.assertEqual('executed 2', tasks['test_factory2']())
-
- def test_procedural_names(self):
- tasks = containers.Tasks()
- names = set(('easy', 'normal', 'hard'))
- for name in names:
- tasks.add(lambda: None, name=name)
- self.assertEqual(len(names), len(tasks))
- self.assertSetEqual(names, set(tasks.keys()))
-
- def test_iteration_order(self):
- tasks = containers.Tasks()
- expected_order = ['first', 'second', 'third', 'fourth']
- for name in expected_order:
- tasks.add(lambda: None, name=name)
- actual_order = list(tasks)
- self.assertEqual(expected_order, actual_order)
-
-
-class TaggedTaskTest(absltest.TestCase):
+class TaggedTaskTest(parameterized.TestCase):
def test_registration(self):
tasks = containers.TaggedTasks()
@@ -88,14 +41,14 @@ def test_factory3(): # pylint: disable=unused-variable
def test_factory4(): # pylint: disable=unused-variable
return 'executed 4'
- self.assertEqual(4, len(tasks))
+ self.assertLen(tasks, 4)
self.assertEqual(set(['basic', 'expert', 'stable', 'unstable']),
set(tasks.tags()))
- self.assertEqual(1, len(tasks.tagged('basic')))
- self.assertEqual(2, len(tasks.tagged('expert')))
- self.assertEqual(2, len(tasks.tagged('stable')))
- self.assertEqual(1, len(tasks.tagged('unstable')))
+ self.assertLen(tasks.tagged('basic'), 1)
+ self.assertLen(tasks.tagged('expert'), 2)
+ self.assertLen(tasks.tagged('stable'), 2)
+ self.assertLen(tasks.tagged('unstable'), 1)
self.assertEqual('executed 2', tasks['test_factory2']())
@@ -126,5 +79,47 @@ def fourth(): # pylint: disable=unused-variable
actual_order = list(tasks)
self.assertEqual(expected_order, actual_order)
+ def test_override_behavior(self):
+ tasks = containers.TaggedTasks(allow_overriding_keys=False)
+
+ @tasks.add()
+ def some_func():
+ pass
+
+ expected_message = containers._NAME_ALREADY_EXISTS.format(name='some_func')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
+ tasks.add()(some_func)
+
+ tasks.allow_overriding_keys = True
+ tasks.add()(some_func) # Override should now succeed.
+
+ @parameterized.parameters(
+ {'query': ['a'], 'expected_keys': frozenset(['f1', 'f2', 'f3'])},
+ {'query': ['b', 'c'], 'expected_keys': frozenset(['f2'])},
+ {'query': ['c'], 'expected_keys': frozenset(['f2', 'f3'])},
+ {'query': ['b', 'd'], 'expected_keys': frozenset()},
+ {'query': ['e'], 'expected_keys': frozenset()},
+ {'query': [], 'expected_keys': frozenset()})
+ def test_query_tag_intersection(self, query, expected_keys):
+ tasks = containers.TaggedTasks()
+
+ # pylint: disable=unused-variable
+ @tasks.add('a', 'b')
+ def f1():
+ pass
+
+ @tasks.add('a', 'b', 'c')
+ def f2():
+ pass
+
+ @tasks.add('a', 'c', 'd')
+ def f3():
+ pass
+ # pylint: enable=unused-variable
+
+ result = tasks.tagged(*query)
+ self.assertSetEqual(frozenset(result.keys()), expected_keys)
+
+
if __name__ == '__main__':
absltest.main()
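
Reviewer note: the tag-intersection semantics exercised by `test_query_tag_intersection` can be reproduced in a few lines of plain Python. This is a simplified sketch, not the `TaggedTasks` implementation itself. `TinyTaggedRegistry` is a hypothetical name, and it uses explicit sets rather than dict key views.

```python
import collections


class TinyTaggedRegistry:
  """Hypothetical, simplified stand-in for `TaggedTasks` (illustration only)."""

  def __init__(self, allow_overriding_keys=False):
    self._tasks = collections.OrderedDict()
    self._tags = collections.defaultdict(dict)
    self.allow_overriding_keys = allow_overriding_keys

  def add(self, *tags):
    """Decorator that registers a function under each of the given tags."""
    def wrap(func):
      name = func.__name__
      if name in self._tasks and not self.allow_overriding_keys:
        raise ValueError('{!r} is already registered.'.format(name))
      self._tasks[name] = func
      for tag in tags:
        self._tags[tag][name] = func
      return func
    return wrap

  def tagged(self, *tags):
    """Returns {name: function} for functions carrying all of `tags`."""
    tags = set(tags)
    # An empty query, or any unknown tag, matches nothing.
    if not tags or not tags.issubset(self._tags.keys()):
      return {}
    names = set(self._tags[tags.pop()])
    for tag in tags:
      names &= set(self._tags[tag])
    return {name: self._tasks[name] for name in names}


registry = TinyTaggedRegistry()


@registry.add('a', 'b')
def f1():
  pass


@registry.add('a', 'b', 'c')
def f2():
  pass


@registry.add('a', 'c', 'd')
def f3():
  pass


print(sorted(registry.tagged('b', 'c')))  # ['f2']
print(sorted(registry.tagged('b', 'd')))  # []
```

Checking `issubset` before indexing matters because `_tags` is a `defaultdict`: looking up an unknown tag directly would silently create an empty entry for it.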
diff --git a/dm_control/utils/corruptors.py b/dm_control/utils/corruptors.py
deleted file mode 100644
index c365e08e..00000000
--- a/dm_control/utils/corruptors.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Corruptors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import collections
-import copy
-import functools
-
-# Internal dependencies.
-
-import numpy as np
-import six
-
-
-@six.add_metaclass(abc.ABCMeta)
-class CorruptorBase(object):
-
- @abc.abstractmethod
- def __call__(self, x):
- """Returns a corrupted version of the input x."""
-
- @abc.abstractmethod
- def reset(self):
- """Resets the internal state of the corruptor."""
-
-
-class Delay(CorruptorBase):
- """Applies a delay to the input."""
-
- def __init__(self, steps, padding=None):
- """Initialize an instance of `Delay`.
-
- Args:
- steps: An int, number of steps for the delay.
- padding: An optional numpy array or a function. The output in the
- first `steps`.
-
- Raises:
- ValueError: When `steps` <= 0.
- """
- if steps <= 0:
- raise ValueError('Delay steps should be greater than 0, %d found', steps)
- self._buffer = collections.deque(maxlen=steps + 1)
- self._padding = padding or (lambda x: np.zeros(x.shape))
-
- def __call__(self, x):
- """Returns the input to this function from `steps` calls ago."""
- self._buffer.append(copy.deepcopy(x))
- if len(self._buffer) == self._buffer.maxlen:
- return self._buffer.popleft()
- else:
- return self._padding(x) if callable(self._padding) else self._padding
-
- def reset(self):
- """Resets the buffer."""
- self._buffer.clear()
-
-
-class StatelessNoise(CorruptorBase):
- """Applies noise to an input without relying on any internal state."""
-
- def __init__(self, noise_function, **noise_parameters):
- """Initialize an instance of `StatelessNoise`.
-
- Args:
- noise_function: A function, adding noise to its input.
- **noise_parameters: Additional keyword arguments taken by the
- `noise_function`.
- """
- self._noise_function = functools.partial(noise_function, **noise_parameters)
-
- def __call__(self, x):
- """Returns the input to this function with noise added."""
- return self._noise_function(x)
-
- def reset(self):
- pass
-
-
-def gaussian_noise(x, std):
- """Adds gaussian noise to each dimension of x.
-
- Example of gaussian noise corruptor:
- ```python
- corruptor = StatelessNoise(noise_function=gaussian_noise,
- noise_parameter={'std': .1})
- ```
-
- Args:
- x: A numpy array, the input.
- std: A number, standard deviation of the gaussian noise.
-
- Returns:
- A numpy array with the same dimension as x, which adds a noise draw from a
- normal distribution to each dimension of x.
- """
- return x + np.random.standard_normal(x.shape) * std
diff --git a/dm_control/utils/corruptors_test.py b/dm_control/utils/corruptors_test.py
deleted file mode 100644
index a651f933..00000000
--- a/dm_control/utils/corruptors_test.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright 2017 The dm_control Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Tests for corruptors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
-from absl.testing import absltest
-from absl.testing import parameterized
-
-from dm_control.utils import corruptors
-
-import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-
-class DelayTest(absltest.TestCase):
-
- def setUp(self):
- self.n = 10
- self.delay = corruptors.Delay(steps=self.n)
- super(DelayTest, self).setUp()
-
- def testProcess(self):
- obs = np.array(range(2 * self.n))
- actual_obs = []
- for i in obs:
- actual_obs.append(self.delay(i))
- expected = np.hstack(([.0] * self.n, obs[:self.n]))
- np.testing.assert_array_equal(expected, actual_obs)
-
- actual_obs = []
- for i in obs:
- actual_obs.append(self.delay(i))
- expected = np.hstack((obs[self.n:], obs[:self.n]))
- np.testing.assert_array_equal(expected, actual_obs)
-
- def testReset(self):
- obs = np.array(range(2 * self.n))
- for _ in xrange(2):
- actual_obs = []
- for i in obs:
- actual_obs.append(self.delay(i))
- self.delay.reset()
-
- expected = np.hstack(([.0] * self.n, obs[:self.n]))
- np.testing.assert_array_equal(expected, actual_obs)
-
-
-class StatelessNoiseTest(absltest.TestCase):
-
- def testProcess(self):
- c = corruptors.StatelessNoise(noise_function=corruptors.gaussian_noise,
- std=1e-3)
- x = np.array([.0] * 3)
- y = np.array([.0] * 3)
- n = 1e3
- for _ in xrange(int(n)):
- y += c(x)
- y /= n
- np.testing.assert_allclose(x, y, atol=1e-4)
-
-
-class NoiseFunctionTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- ('1D', np.array([3., 4.]), .1),
- ('2D', np.array([[.0, .1], [1., 2.]]), .4)
- )
- def testGaussianNoise_Shape(self, x, std):
- noisy_x = corruptors.gaussian_noise(x, std)
- self.assertEqual(x.shape, noisy_x.shape)
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/dm_control/utils/inverse_kinematics.py b/dm_control/utils/inverse_kinematics.py
new file mode 100644
index 00000000..623d2ea9
--- /dev/null
+++ b/dm_control/utils/inverse_kinematics.py
@@ -0,0 +1,260 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Functions for computing inverse kinematics on MuJoCo models."""
+
+import collections
+
+from absl import logging
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+
+_INVALID_JOINT_NAMES_TYPE = (
+ '`joint_names` must be either None, a list, a tuple, or a numpy array; '
+ 'got {}.')
+_REQUIRE_TARGET_POS_OR_QUAT = (
+ 'At least one of `target_pos` or `target_quat` must be specified.')
+
+IKResult = collections.namedtuple(
+ 'IKResult', ['qpos', 'err_norm', 'steps', 'success'])
+
+
+def qpos_from_site_pose(physics,
+ site_name,
+ target_pos=None,
+ target_quat=None,
+ joint_names=None,
+ tol=1e-14,
+ rot_weight=1.0,
+ regularization_threshold=0.1,
+ regularization_strength=3e-2,
+ max_update_norm=2.0,
+ progress_thresh=20.0,
+ max_steps=100,
+ inplace=False):
+ """Find joint positions that satisfy a target site position and/or rotation.
+
+ Args:
+ physics: A `mujoco.Physics` instance.
+ site_name: A string specifying the name of the target site.
+ target_pos: A (3,) numpy array specifying the desired Cartesian position of
+ the site, or None if the position should be unconstrained (default).
+ One or both of `target_pos` or `target_quat` must be specified.
+ target_quat: A (4,) numpy array specifying the desired orientation of the
+ site as a quaternion, or None if the orientation should be unconstrained
+ (default). One or both of `target_pos` or `target_quat` must be specified.
+ joint_names: (optional) A list, tuple or numpy array specifying the names of
+ one or more joints that can be manipulated in order to achieve the target
+ site pose. If None (default), all joints may be manipulated.
+ tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
+ in the stopping criterion).
+ rot_weight: (optional) Determines the weight given to rotational error
+ relative to translational error.
+ regularization_threshold: (optional) L2 regularization will be used when
+ inverting the Jacobian whilst `err_norm` is greater than this value.
+ regularization_strength: (optional) Coefficient of the quadratic penalty
+ on joint movements.
+ max_update_norm: (optional) The maximum L2 norm of the update applied to
+ the joint positions on each iteration. The update vector will be scaled
+ such that its magnitude never exceeds this value.
+ progress_thresh: (optional) If `err_norm` divided by the magnitude of the
+ joint position update is greater than this value then the optimization
+ will terminate prematurely. This is a useful heuristic to avoid getting
+ stuck in local minima.
+ max_steps: (optional) The maximum number of iterations to perform.
+ inplace: (optional) If True, `physics.data` will be modified in place.
+ Default value is False, i.e. a copy of `physics.data` will be made.
+
+ Returns:
+ An `IKResult` namedtuple with the following fields:
+ qpos: An (nq,) numpy array of joint positions.
+ err_norm: A float, the weighted sum of L2 norms for the residual
+ translational and rotational errors.
+ steps: An int, the number of iterations that were performed.
+ success: Boolean, True if we converged on a solution within `max_steps`,
+ False otherwise.
+
+ Raises:
+ ValueError: If both `target_pos` and `target_quat` are None, or if
+ `joint_names` has an invalid type.
+ """
+
+ dtype = physics.data.qpos.dtype
+
+ if target_pos is not None and target_quat is not None:
+ jac = np.empty((6, physics.model.nv), dtype=dtype)
+ err = np.empty(6, dtype=dtype)
+ jac_pos, jac_rot = jac[:3], jac[3:]
+ err_pos, err_rot = err[:3], err[3:]
+ else:
+ jac = np.empty((3, physics.model.nv), dtype=dtype)
+ err = np.empty(3, dtype=dtype)
+ if target_pos is not None:
+ jac_pos, jac_rot = jac, None
+ err_pos, err_rot = err, None
+ elif target_quat is not None:
+ jac_pos, jac_rot = None, jac
+ err_pos, err_rot = None, err
+ else:
+ raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
+
+ update_nv = np.zeros(physics.model.nv, dtype=dtype)
+
+ if target_quat is not None:
+ site_xquat = np.empty(4, dtype=dtype)
+ neg_site_xquat = np.empty(4, dtype=dtype)
+ err_rot_quat = np.empty(4, dtype=dtype)
+
+ if not inplace:
+ physics = physics.copy(share_model=True)
+
+ # Ensure that the Cartesian position of the site is up to date.
+ mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
+
+ # Convert site name to index.
+ site_id = physics.model.name2id(site_name, 'site')
+
+ # These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
+ # update them in place, so we can avoid indexing overhead in the main loop.
+ site_xpos = physics.named.data.site_xpos[site_name]
+ site_xmat = physics.named.data.site_xmat[site_name]
+
+ # This is an index into the rows of `update` and the columns of `jac`
+ # that selects DOFs associated with joints that we are allowed to manipulate.
+ if joint_names is None:
+ dof_indices = slice(None) # Update all DOFs.
+ elif isinstance(joint_names, (list, np.ndarray, tuple)):
+ if isinstance(joint_names, tuple):
+ joint_names = list(joint_names)
+ # Find the indices of the DOFs belonging to each named joint. Note that
+ # these are not necessarily the same as the joint IDs, since a single joint
+ # may have >1 DOF (e.g. ball joints).
+ indexer = physics.named.model.dof_jntid.axes.row
+ # `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
+ # indexer to map each joint name to the indices of its corresponding DOFs.
+ dof_indices = indexer.convert_key_item(joint_names)
+ else:
+ raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
+
+ steps = 0
+ success = False
+
+ for steps in range(max_steps):
+
+ err_norm = 0.0
+
+ if target_pos is not None:
+ # Translational error.
+ err_pos[:] = target_pos - site_xpos
+ err_norm += np.linalg.norm(err_pos)
+ if target_quat is not None:
+ # Rotational error.
+ mjlib.mju_mat2Quat(site_xquat, site_xmat)
+ mjlib.mju_negQuat(neg_site_xquat, site_xquat)
+ mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
+ mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
+ err_norm += np.linalg.norm(err_rot) * rot_weight
+
+ if err_norm < tol:
+ logging.debug('Converged after %i steps: err_norm=%3g', steps, err_norm)
+ success = True
+ break
+ else:
+ # TODO(b/112141670): Generalize this to other entities besides sites.
+ mjlib.mj_jacSite(
+ physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id)
+ jac_joints = jac[:, dof_indices]
+
+ # TODO(b/112141592): This does not take joint limits into consideration.
+ reg_strength = (
+ regularization_strength if err_norm > regularization_threshold
+ else 0.0)
+ update_joints = nullspace_method(
+ jac_joints, err, regularization_strength=reg_strength)
+
+ update_norm = np.linalg.norm(update_joints)
+
+ # Check whether we are still making enough progress, and halt if not.
+ progress_criterion = err_norm / update_norm
+ if progress_criterion > progress_thresh:
+ logging.debug('Step %2i: err_norm / update_norm (%3g) > '
+ 'tolerance (%3g). Halting due to insufficient progress',
+ steps, progress_criterion, progress_thresh)
+ break
+
+ if update_norm > max_update_norm:
+ update_joints *= max_update_norm / update_norm
+
+ # Write the entries for the specified joints into the full `update_nv`
+ # vector.
+ update_nv[dof_indices] = update_joints
+
+ # Update `physics.qpos`, taking quaternions into account.
+ mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
+
+ # Compute the new Cartesian position of the site.
+ mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
+
+ logging.debug('Step %2i: err_norm=%-10.3g update_norm=%-10.3g',
+ steps, err_norm, update_norm)
+
+ if not success and steps == max_steps - 1:
+ logging.warning('Failed to converge after %i steps: err_norm=%3g',
+ steps, err_norm)
+
+ if not inplace:
+ # Our temporary copy of physics.data is about to go out of scope, and when
+ # it does the underlying mjData pointer will be freed and physics.data.qpos
+ # will be a view onto a block of deallocated memory. We therefore need to
+ # make a copy of physics.data.qpos while physics.data is still alive.
+ qpos = physics.data.qpos.copy()
+ else:
+ # If we're modifying physics.data in place then it's fine to return a view.
+ qpos = physics.data.qpos
+
+ return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
+
+
+def nullspace_method(jac_joints, delta, regularization_strength=0.0):
+ """Calculates the joint velocities to achieve a specified end effector delta.
+
+ Args:
+ jac_joints: The Jacobian of the end effector with respect to the joints. A
+ numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
+ and `nv` is the number of degrees of freedom.
+ delta: The desired end-effector delta. A numpy array of shape `(3,)` or
+ `(6,)` containing either position deltas, rotation deltas, or both.
+ regularization_strength: (optional) Coefficient of the quadratic penalty
+ on joint movements. Default is zero, i.e. no regularization.
+
+ Returns:
+ An `(nv,)` numpy array of joint velocities.
+
+ Reference:
+ Buss, S. R. (2004). Introduction to inverse kinematics with Jacobian
+ transpose, pseudoinverse and damped least squares methods.
+ https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
+ """
+ hess_approx = jac_joints.T.dot(jac_joints)
+ joint_delta = jac_joints.T.dot(delta)
+ if regularization_strength > 0:
+ # L2 regularization
+ hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
+ return np.linalg.solve(hess_approx, joint_delta)
+ else:
+ return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
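The update rule in `nullspace_method` above can be tried out in isolation. The following is a minimal NumPy-only sketch, not part of the patch: the function name `damped_least_squares_step`, the `damping` argument and the example Jacobian are hypothetical, and it passes `rcond=None` to `np.linalg.lstsq` where the patch uses `rcond=-1`.

```python
import numpy as np


def damped_least_squares_step(jac, delta, damping=0.0):
  """Solves (J^T J + damping * I) dq = J^T delta for the joint update dq."""
  hess_approx = jac.T.dot(jac)
  joint_delta = jac.T.dot(delta)
  if damping > 0:
    # L2 regularization makes the system positive definite, so a direct
    # solve is safe.
    hess_approx += np.eye(hess_approx.shape[0]) * damping
    return np.linalg.solve(hess_approx, joint_delta)
  # Without damping, J^T J can be singular, so fall back to least squares.
  return np.linalg.lstsq(hess_approx, joint_delta, rcond=None)[0]


# Hypothetical Jacobian for a redundant arm: 3 task dimensions, 4 DOFs.
jac = np.array([[1., 0., 0., 1.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]])
delta = np.array([0.1, -0.2, 0.05])

dq_plain = damped_least_squares_step(jac, delta)
dq_damped = damped_least_squares_step(jac, delta, damping=3e-2)
print(np.linalg.norm(dq_plain), np.linalg.norm(dq_damped))
```

In this sketch the undamped update reproduces `delta` when mapped back through the Jacobian, while the damped update is shorter. That is consistent with the patch applying regularization only while `err_norm` exceeds `regularization_threshold`.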
diff --git a/dm_control/utils/inverse_kinematics_test.py b/dm_control/utils/inverse_kinematics_test.py
new file mode 100644
index 00000000..493a1dbe
--- /dev/null
+++ b/dm_control/utils/inverse_kinematics_test.py
@@ -0,0 +1,215 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for inverse_kinematics."""
+
+import itertools
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mujoco
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import inverse_kinematics as ik
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+_ARM_XML = assets.get_contents('arm.xml')
+_MODEL_WITH_BALL_JOINTS_XML = assets.get_contents('model_with_ball_joints.xml')
+
+_SITE_NAME = 'gripsite'
+_JOINTS = ['joint_1', 'joint_2', 'joint_3', 'joint_4', 'joint_5', 'joint_6']
+_TOL = 1.2e-14
+_MAX_STEPS = 100
+_MAX_RESETS = 10
+
+_TARGETS = [
+ # target_pos # target_quat
+ (np.array([0., 0., 0.3]), np.array([0., 1., 0., 1.])),
+ (np.array([-0.5, 0., 0.5]), None),
+ (np.array([0., 0., 0.8]), np.array([0., 1., 0., 1.])),
+ (np.array([0., 0., 0.8]), None),
+ (np.array([0., -0.1, 0.5]), None),
+ (np.array([0., -0.1, 0.5]), np.array([1., 1., 0., 0.])),
+ (np.array([0.5, 0., 0.5]), None),
+ (np.array([0.4, 0.1, 0.5]), None),
+ (np.array([0.4, 0.1, 0.5]), np.array([1., 0., 0., 0.])),
+ (np.array([0., 0., 0.3]), None),
+ (np.array([0., 0.5, -0.2]), None),
+ (np.array([0.5, 0., 0.3]), np.array([1., 0., 0., 1.])),
+ (None, np.array([1., 0., 0., 1.])),
+ (None, np.array([0., 1., 0., 1.])),
+]
+_INPLACE = [False, True]
+
+
+class _ResetArm:
+
+ def __init__(self, seed=None):
+ self._rng = np.random.RandomState(seed)
+ self._lower = None
+ self._upper = None
+
+ def _cache_bounds(self, physics):
+ self._lower, self._upper = physics.named.model.jnt_range[_JOINTS].T
+ limited = physics.named.model.jnt_limited[_JOINTS].astype(bool)
+ # Positions for hinge joints without limits are sampled between 0 and 2pi
+ self._lower[~limited] = 0
+ self._upper[~limited] = 2 * np.pi
+
+ def __call__(self, physics):
+ if self._lower is None:
+ self._cache_bounds(physics)
+ # NB: This won't work for joints with > 1 DOF
+ new_qpos = self._rng.uniform(self._lower, self._upper)
+ physics.named.data.qpos[_JOINTS] = new_qpos
+
+
+class InverseKinematicsTest(parameterized.TestCase):
+
+ @parameterized.parameters(itertools.product(_TARGETS, _INPLACE))
+ def testQposFromSitePose(self, target, inplace):
+ physics = mujoco.Physics.from_xml_string(_ARM_XML)
+ target_pos, target_quat = target
+ count = 0
+ physics2 = physics.copy(share_model=True)
+ resetter = _ResetArm(seed=0)
+ while True:
+ result = ik.qpos_from_site_pose(
+ physics=physics2,
+ site_name=_SITE_NAME,
+ target_pos=target_pos,
+ target_quat=target_quat,
+ joint_names=_JOINTS,
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=inplace,
+ )
+ if result.success:
+ break
+ elif count < _MAX_RESETS:
+ resetter(physics2)
+ count += 1
+ else:
+ raise RuntimeError(
+ 'Failed to find a solution within %i attempts.' % _MAX_RESETS)
+
+ self.assertLessEqual(result.steps, _MAX_STEPS)
+ self.assertLessEqual(result.err_norm, _TOL)
+ physics.data.qpos[:] = result.qpos
+ mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
+ if target_pos is not None:
+ pos = physics.named.data.site_xpos[_SITE_NAME]
+ np.testing.assert_array_almost_equal(pos, target_pos)
+ if target_quat is not None:
+ xmat = physics.named.data.site_xmat[_SITE_NAME]
+ quat = np.empty_like(target_quat)
+ mjlib.mju_mat2Quat(quat, xmat)
+ quat /= np.ptp(quat) # Normalize xquat so that its max-min range is 1
+ np.testing.assert_array_almost_equal(quat, target_quat)
+
+ def testNamedJointsWithMultipleDOFs(self):
+ """Regression test for b/77506142."""
+ physics = mujoco.Physics.from_xml_string(_MODEL_WITH_BALL_JOINTS_XML)
+ site_name = 'gripsite'
+ joint_names = ['joint_1', 'joint_2']
+
+ # This target position can only be achieved by rotating both ball joints. If
+ # DOFs are incorrectly indexed by joint index then only the first two DOFs
+ # in the first ball joint can change, and IK will fail to find a solution.
+ target_pos = (0.05, 0.05, 0)
+ result = ik.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ target_pos=target_pos,
+ joint_names=joint_names,
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=True)
+
+ self.assertTrue(result.success)
+ self.assertLessEqual(result.steps, _MAX_STEPS)
+ self.assertLessEqual(result.err_norm, _TOL)
+ pos = physics.named.data.site_xpos[site_name]
+ np.testing.assert_array_almost_equal(pos, target_pos)
+
+ # IK should fail to converge if only the first joint is allowed to move.
+ physics.reset()
+ result = ik.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ target_pos=target_pos,
+ joint_names=joint_names[:1],
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=True)
+ self.assertFalse(result.success)
+
+ @parameterized.named_parameters(
+ ('None', None),
+ ('list', ['joint_1', 'joint_2']),
+ ('tuple', ('joint_1', 'joint_2')),
+ ('numpy.array', np.array(['joint_1', 'joint_2'])))
+ def testAllowedJointNameTypes(self, joint_names):
+ """Test allowed types for joint_names parameter."""
+ physics = mujoco.Physics.from_xml_string(_ARM_XML)
+ site_name = 'gripsite'
+ target_pos = (0.05, 0.05, 0)
+ ik.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ target_pos=target_pos,
+ joint_names=joint_names,
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=True)
+
+ @parameterized.named_parameters(
+ ('int', 1),
+ ('dict', {'joint_1': 1, 'joint_2': 2}),
+ ('function', lambda x: x),
+ )
+ def testDisallowedJointNameTypes(self, joint_names):
+ physics = mujoco.Physics.from_xml_string(_ARM_XML)
+ site_name = 'gripsite'
+ target_pos = (0.05, 0.05, 0)
+
+ expected_message = ik._INVALID_JOINT_NAMES_TYPE.format(type(joint_names))
+
+ with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
+ ik.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ target_pos=target_pos,
+ joint_names=joint_names,
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=True)
+
+ def testNoTargetPosOrQuat(self):
+ physics = mujoco.Physics.from_xml_string(_ARM_XML)
+ site_name = 'gripsite'
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, ik._REQUIRE_TARGET_POS_OR_QUAT):
+ ik.qpos_from_site_pose(
+ physics=physics,
+ site_name=site_name,
+ tol=_TOL,
+ max_steps=_MAX_STEPS,
+ inplace=True)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/utils/resources.py b/dm_control/utils/io.py
similarity index 80%
rename from dm_control/utils/resources.py
rename to dm_control/utils/io.py
index b5976fd1..8ed9b94f 100644
--- a/dm_control/utils/resources.py
+++ b/dm_control/utils/io.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The dm_control Authors.
+# Copyright 2017-2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================
+
"""IO functions."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+import os
def GetResource(name, mode='rb'):
@@ -27,4 +27,9 @@ def GetResourceFilename(name, mode='rb'):
del mode # Unused.
return name
+
+def WalkResources(path):
+ return os.walk(path)
+
+
GetResourceAsFile = open # pylint: disable=invalid-name
diff --git a/dm_control/utils/rewards.py b/dm_control/utils/rewards.py
index 2e2fba55..8f25b5f5 100644
--- a/dm_control/utils/rewards.py
+++ b/dm_control/utils/rewards.py
@@ -15,11 +15,7 @@
"""Soft indicator function evaluating whether a number is within bounds."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
+import warnings
import numpy as np
# The value returned by tolerance() at `margin` distance from `bounds` interval.
@@ -63,10 +59,18 @@ def _sigmoids(x, value_at_1, sigmoid):
scale = np.sqrt(1/value_at_1 - 1)
return 1 / ((x*scale)**2 + 1)
+ elif sigmoid == 'reciprocal':
+ scale = 1/value_at_1 - 1
+ return 1 / (abs(x)*scale + 1)
+
elif sigmoid == 'cosine':
scale = np.arccos(2*value_at_1 - 1) / np.pi
scaled_x = x*scale
- return np.where(abs(scaled_x) < 1, (1 + np.cos(np.pi*scaled_x))/2, 0.0)
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ action='ignore', message='invalid value encountered in cos')
+ cos_pi_scaled_x = np.cos(np.pi*scaled_x)
+ return np.where(abs(scaled_x) < 1, (1 + cos_pi_scaled_x)/2, 0.0)
elif sigmoid == 'linear':
scale = 1-value_at_1
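The `'reciprocal'` branch added to `_sigmoids` above can be checked on its own. This is a minimal sketch, not part of the patch: the standalone function name `reciprocal_sigmoid` is hypothetical, and it assumes `0 < value_at_1 < 1` without validating it.

```python
import numpy as np


def reciprocal_sigmoid(x, value_at_1):
  """Returns 1 at x == 0 and `value_at_1` at |x| == 1, decaying as 1/|x|."""
  # Assumes 0 < value_at_1 < 1; no validation is performed in this sketch.
  scale = 1 / value_at_1 - 1
  return 1 / (np.abs(x) * scale + 1)


xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
values = reciprocal_sigmoid(xs, value_at_1=0.1)
print(values)
```

The scale is chosen so that the output equals `value_at_1` exactly where `|x| == 1`, which matches the parameterisation used by the other sigmoid branches in the same function.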
diff --git a/dm_control/utils/rewards_test.py b/dm_control/utils/rewards_test.py
index a3428ca2..5fa0b667 100644
--- a/dm_control/utils/rewards_test.py
+++ b/dm_control/utils/rewards_test.py
@@ -15,17 +15,9 @@
"""Tests for dm_control.utils.rewards."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
-
from absl.testing import absltest
from absl.testing import parameterized
-
from dm_control.utils import rewards
-
import numpy as np
@@ -45,7 +37,7 @@ def test_tolerance_sigmoid_parameterisation(self, margin, value_at_margin):
@parameterized.parameters(("gaussian",), ("hyperbolic",), ("long_tail",),
("cosine",), ("tanh_squared",), ("linear",),
- ("quadratic"))
+ ("quadratic",), ("reciprocal",))
def test_tolerance_sigmoids(self, sigmoid):
margins = [0.01, 1.0, 100, 10000]
values_at_margin = [0.1, 0.5, 0.9]
diff --git a/dm_control/utils/transformations.py b/dm_control/utils/transformations.py
new file mode 100644
index 00000000..74289c35
--- /dev/null
+++ b/dm_control/utils/transformations.py
@@ -0,0 +1,659 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Rigid-body transformations including velocities and static forces."""
+
+
+from absl import logging
+import numpy as np
+
+# Constants used to determine when a rotation is close to a pole.
+_POLE_LIMIT = (1.0 - 1e-6)
+_TOL = 1e-10
+
+
+def _clip_within_precision(number, low, high, precision=_TOL):
+ """Clips input to provided range, checking precision.
+
+ Args:
+ number: (float) number to be clipped.
+ low: (float) lower bound.
+ high: (float) upper bound.
+ precision: (float) tolerance.
+
+ Returns:
+ Input clipped to given range.
+
+ Raises:
+ ValueError: If number is outside given range by more than given precision.
+ """
+ if (number < low - precision).any() or (number > high + precision).any():
+ raise ValueError(
+ 'Input {:.12f} not inside range [{:.12f}, {:.12f}] with precision {}'.
+ format(number, low, high, precision))
+
+ return np.clip(number, low, high)
+
+
+def _batch_mm(m1, m2):
+ """Batch matrix multiply.
+
+ Args:
+ m1: input lhs matrix with shape (batch, n, m).
+ m2: input rhs matrix with shape (batch, m, o).
+
+ Returns:
+ product matrix with shape (batch, n, o).
+ """
+ return np.einsum('bij,bjk->bik', m1, m2)
+
+
+def _rmat_to_euler_xyz(rmat):
+ """Converts a 3x3 rotation matrix to XYZ euler angles."""
+ # | r00 r01 r02 | | cy*cz -cy*sz sy |
+ # | r10 r11 r12 | = | cz*sx*sy+cx*sz cx*cz-sx*sy*sz -cy*sx |
+ # | r20 r21 r22 | | -cx*cz*sy+sx*sz cz*sx+cx*sy*sz cx*cy |
+ if rmat[0, 2] > _POLE_LIMIT:
+ logging.log_every_n_seconds(logging.WARNING, 'Angle at North Pole', 60)
+ z = np.arctan2(rmat[1, 0], rmat[1, 1])
+ y = np.pi/2
+ x = 0.0
+ return np.array([x, y, z])
+
+ if rmat[0, 2] < -_POLE_LIMIT:
+ logging.log_every_n_seconds(logging.WARNING, 'Angle at South Pole', 60)
+ z = np.arctan2(rmat[1, 0], rmat[1, 1])
+ y = -np.pi/2
+ x = 0.0
+ return np.array([x, y, z])
+
+ z = -np.arctan2(rmat[0, 1], rmat[0, 0])
+ y = np.arcsin(rmat[0, 2])
+ x = -np.arctan2(rmat[1, 2], rmat[2, 2])
+
+ # order of return is the order of input
+ return np.array([x, y, z])
+
+
+def _rmat_to_euler_xyx(rmat):
+ """Converts a 3x3 rotation matrix to XYX euler angles."""
+ # | r00 r01 r02 | | cy sy*sx1 sy*cx1 |
+ # | r10 r11 r12 | = | sy*sx0 cx0*cx1-cy*sx0*sx1 -cy*cx1*sx0-cx0*sx1 |
+ # | r20 r21 r22 | | -sy*cx0 cx1*sx0+cy*cx0*sx1 cy*cx0*cx1-sx0*sx1 |
+
+ if rmat[0, 0] < 1.0:
+ if rmat[0, 0] > -1.0:
+ y = np.arccos(_clip_within_precision(rmat[0, 0], -1., 1.))
+ x0 = np.arctan2(rmat[1, 0], -rmat[2, 0])
+ x1 = np.arctan2(rmat[0, 1], rmat[0, 2])
+ return np.array([x0, y, x1])
+ else:
+ # Not a unique solution: x1_angle - x0_angle = atan2(-r12,r11)
+ y = np.pi
+ x0 = -np.arctan2(-rmat[1, 2], rmat[1, 1])
+ x1 = 0.0
+ return np.array([x0, y, x1])
+ else:
+ # Not a unique solution: x1_angle + x0_angle = atan2(-r12,r11)
+ y = 0.0
+ x0 = -np.arctan2(-rmat[1, 2], rmat[1, 1])
+ x1 = 0.0
+ return np.array([x0, y, x1])
+
+
+def _rmat_to_euler_zyx(rmat):
+ """Converts a 3x3 rotation matrix to ZYX euler angles."""
+ if rmat[2, 0] > _POLE_LIMIT:
+ logging.warning('Angle at North Pole')
+ x = np.arctan2(rmat[0, 1], rmat[0, 2])
+ y = -np.pi/2
+ z = 0.0
+ return np.array([z, y, x])
+
+ if rmat[2, 0] < -_POLE_LIMIT:
+ logging.warning('Angle at South Pole')
+ x = np.arctan2(rmat[0, 1], rmat[0, 2])
+ y = np.pi/2
+ z = 0.0
+ return np.array([z, y, x])
+
+ x = np.arctan2(rmat[2, 1], rmat[2, 2])
+ y = -np.arcsin(rmat[2, 0])
+ z = np.arctan2(rmat[1, 0], rmat[0, 0])
+
+ # order of return is the order of input
+ return np.array([z, y, x])
+
+
+def _rmat_to_euler_xzy(rmat):
+ """Converts a 3x3 rotation matrix to XZY euler angles."""
+ if rmat[0, 1] > _POLE_LIMIT:
+ logging.warning('Angle at North Pole')
+ y = np.arctan2(rmat[1, 2], rmat[1, 0])
+ z = -np.pi/2
+ x = 0.0
+ return np.array([x, z, y])
+
+ if rmat[0, 1] < -_POLE_LIMIT:
+ logging.warning('Angle at South Pole')
+ y = np.arctan2(rmat[1, 2], rmat[1, 0])
+ z = np.pi/2
+ x = 0.0
+ return np.array([x, z, y])
+
+ y = np.arctan2(rmat[0, 2], rmat[0, 0])
+ z = -np.arcsin(rmat[0, 1])
+ x = np.arctan2(rmat[2, 1], rmat[1, 1])
+
+ # order of return is the order of input
+ return np.array([x, z, y])
+
+
+def _rmat_to_euler_yzx(rmat):
+ """Converts a 3x3 rotation matrix to YZX euler angles."""
+ if rmat[1, 0] > _POLE_LIMIT:
+ logging.warning('Angle at North Pole')
+ x = -np.arctan2(rmat[0, 2], rmat[0, 1])
+ z = np.pi/2
+ y = 0.0
+ return np.array([y, z, x])
+
+ if rmat[1, 0] < -_POLE_LIMIT:
+ logging.warning('Angle at South Pole')
+ x = -np.arctan2(rmat[0, 2], rmat[0, 1])
+ z = -np.pi/2
+ y = 0.0
+ return np.array([y, z, x])
+
+ x = -np.arctan2(rmat[1, 2], rmat[1, 1])
+ z = np.arcsin(rmat[1, 0])
+ y = -np.arctan2(rmat[2, 0], rmat[0, 0])
+
+ # order of return is the order of input
+ return np.array([y, z, x])
+
+
+def _rmat_to_euler_zxy(rmat):
+ """Converts a 3x3 rotation matrix to ZXY euler angles."""
+ if rmat[2, 1] > _POLE_LIMIT:
+ logging.warning('Angle at North Pole')
+ y = np.arctan2(rmat[0, 2], rmat[0, 0])
+ x = np.pi/2
+ z = 0.0
+ return np.array([z, x, y])
+
+ if rmat[2, 1] < -_POLE_LIMIT:
+ logging.warning('Angle at South Pole')
+ y = np.arctan2(rmat[0, 2], rmat[0, 0])
+ x = -np.pi/2
+ z = 0.0
+ return np.array([z, x, y])
+
+ y = -np.arctan2(rmat[2, 0], rmat[2, 2])
+ x = np.arcsin(rmat[2, 1])
+ z = -np.arctan2(rmat[0, 1], rmat[1, 1])
+
+ # order of return is the order of input
+ return np.array([z, x, y])
+
+
+def _rmat_to_euler_yxz(rmat):
+ """Converts a 3x3 rotation matrix to YXZ euler angles."""
+ if rmat[1, 2] > _POLE_LIMIT:
+ logging.warning('Angle at North Pole')
+ z = -np.arctan2(rmat[0, 1], rmat[0, 0])
+ x = -np.pi/2
+ y = 0.0
+ return np.array([y, x, z])
+
+ if rmat[1, 2] < -_POLE_LIMIT:
+ logging.warning('Angle at South Pole')
+ z = -np.arctan2(rmat[0, 1], rmat[0, 0])
+ x = np.pi/2
+ y = 0.0
+ return np.array([y, x, z])
+
+ z = np.arctan2(rmat[1, 0], rmat[1, 1])
+ x = -np.arcsin(rmat[1, 2])
+ y = np.arctan2(rmat[0, 2], rmat[2, 2])
+
+ # order of return is the order of input
+ return np.array([y, x, z])
+
+
+def _axis_rotation(theta, full):
+ """Returns the theta dim, cos and sin, and blank matrix for axis rotation."""
+ n = 1 if np.isscalar(theta) else len(theta)
+ ct = np.cos(theta)
+ st = np.sin(theta)
+
+ if full:
+ rmat = np.zeros((n, 4, 4))
+ rmat[:, 3, 3] = 1.
+ else:
+ rmat = np.zeros((n, 3, 3))
+
+ return n, ct, st, rmat
+
+# map from full rotation orderings to euler conversion functions
+_eulermap = {
+ 'XYZ': _rmat_to_euler_xyz,
+ 'XYX': _rmat_to_euler_xyx,
+ 'ZYX': _rmat_to_euler_zyx,
+ 'XZY': _rmat_to_euler_xzy,
+ 'YZX': _rmat_to_euler_yzx,
+ 'ZXY': _rmat_to_euler_zxy,
+ 'YXZ': _rmat_to_euler_yxz
+}
+
+
+def euler_to_quat(euler_vec, ordering='XYZ'):
+ """Returns the quaternion corresponding to the provided euler angles.
+
+ Args:
+ euler_vec: The euler angle rotations.
+ ordering: (str) Desired euler angle ordering.
+
+ Returns:
+ quat: A quaternion [w, i, j, k]
+ """
+ mat = euler_to_rmat(euler_vec, ordering=ordering)
+ return mat_to_quat(mat)
+
+
+def euler_to_rmat(euler_vec, ordering='ZXZ', full=False):
+  """Returns rotation matrix (or transform) for the given Euler rotations.
+
+  Composes a rotation matrix corresponding to the given rotations r1, r2, r3
+  following the given rotation ordering. Ordering specifies the order of the
+  rotation matrices in matrix multiplication order.
+  E.g. for XYZ we return rotX(r1) * rotY(r2) * rotZ(r3).
+
+  Args:
+    euler_vec: The euler angle rotations.
+    ordering: (str) Euler angle ordering, three axis letters, e.g. 'XYZ'.
+    full: If true, returns a full 4x4 transform.
+
+  Returns:
+    The rotation matrix or homogeneous transform corresponding to the given
+    Euler rotation.
+  """
+
+  # map from partial rotation orderings to rotation functions
+  rotmap = {'X': rotation_x_axis, 'Y': rotation_y_axis, 'Z': rotation_z_axis}
+  rotations = [rotmap[c] for c in ordering]
+
+  euler_vec = np.atleast_2d(euler_vec)
+
+  rots = []
+  for i in range(len(rotations)):
+    rots.append(rotations[i](euler_vec[:, i], full))
+
+  if rots[0].ndim == 3:
+    result = _batch_mm(_batch_mm(rots[0], rots[1]), rots[2])
+    return result.squeeze()
+  else:
+    return (rots[0].dot(rots[1])).dot(rots[2])
+
+
+def quat_conj(quat):
+ """Return conjugate of quaternion.
+
+ This function supports inputs with or without leading batch dimensions.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+
+ Returns:
+ A quaternion [w, -i, -j, -k] representing the inverse of the rotation
+ defined by `quat` (not assuming normalization).
+ """
+ # Ensure quat is an np.array in case a tuple or a list is passed
+ quat = np.asarray(quat)
+ return np.stack(
+ [quat[..., 0], -quat[..., 1],
+ -quat[..., 2], -quat[..., 3]], axis=-1).astype(np.float64)
+
+
+def quat_inv(quat):
+ """Return inverse of quaternion.
+
+ This function supports inputs with or without leading batch dimensions.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+
+ Returns:
+ A quaternion representing the inverse of the original rotation.
+ """
+ # Ensure quat is an np.array in case a tuple or a list is passed
+ quat = np.asarray(quat)
+ return quat_conj(quat) / np.sum(quat * quat, axis=-1, keepdims=True)
+
+
+def _get_qmat_indices_and_signs():
+ """Precomputes index and sign arrays for constructing `qmat` in `quat_mul`."""
+ w, x, y, z = range(4)
+ qmat_idx_and_sign = np.array([
+ [w, -x, -y, -z],
+ [x, w, -z, y],
+ [y, z, w, -x],
+ [z, -y, x, w],
+ ])
+ indices = np.abs(qmat_idx_and_sign)
+ signs = 2 * (qmat_idx_and_sign >= 0) - 1
+ # Prevent array constants from being modified in place.
+ indices.flags.writeable = False
+ signs.flags.writeable = False
+ return indices, signs
+
+_qmat_idx, _qmat_sign = _get_qmat_indices_and_signs()
+
+
+def quat_mul(quat1, quat2):
+  """Computes the Hamilton product of two quaternions.
+
+  Any number of leading batch dimensions is supported.
+
+  Args:
+    quat1: A quaternion [w, i, j, k].
+    quat2: A quaternion [w, i, j, k].
+
+  Returns:
+    The quaternion product quat1 * quat2.
+  """
+  # Construct a (..., 4, 4) matrix to multiply with quat2 as shown below.
+  qmat = quat1[..., _qmat_idx] * _qmat_sign
+
+  # Compute the batched Hamilton product:
+  # |w1 -i1 -j1 -k1|   |w2|   |w1w2 - i1i2 - j1j2 - k1k2|
+  # |i1  w1 -k1  j1| . |i2| = |w1i2 + i1w2 + j1k2 - k1j2|
+  # |j1  k1  w1 -i1|   |j2|   |w1j2 - i1k2 + j1w2 + k1i2|
+  # |k1 -j1  i1  w1|   |k2|   |w1k2 + i1j2 - j1i2 + k1w2|
+  return (qmat @ quat2[..., None])[..., 0]
+
+
+def quat_diff(source, target):
+ """Computes quaternion difference between source and target quaternions.
+
+ This function supports inputs with or without leading batch dimensions.
+
+ Args:
+ source: A quaternion [w, i, j, k].
+ target: A quaternion [w, i, j, k].
+
+ Returns:
+ A quaternion representing the rotation from source to target.
+ """
+ return quat_mul(quat_conj(source), target)
+
+
+def quat_log(quat, tol=_TOL):
+ """Log of a quaternion.
+
+ This function supports inputs with or without leading batch dimensions.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+ tol: numerical tolerance to prevent nan.
+
+ Returns:
+ 4D array representing the log of `quat`. This is analogous to
+ `rmat_to_axisangle`.
+ """
+ # Ensure quat is an np.array in case a tuple or a list is passed
+ quat = np.asarray(quat)
+ q_norm = np.linalg.norm(quat + tol, axis=-1, keepdims=True)
+ a = quat[..., 0:1]
+ v = np.stack([quat[..., 1], quat[..., 2], quat[..., 3]], axis=-1)
+ # Clip to 2*tol because we subtract it here
+ v_new = v / np.linalg.norm(v + tol, axis=-1, keepdims=True) * np.arccos(
+ _clip_within_precision(a - tol, -1., 1., precision=2.*tol)) / q_norm
+ return np.stack(
+ [np.log(q_norm[..., 0]), v_new[..., 0], v_new[..., 1], v_new[..., 2]],
+ axis=-1)
+
+
+def quat_dist(source, target):
+ """Computes distance between source and target quaternions.
+
+ This function assumes that both input arguments are unit quaternions.
+
+ This function supports inputs with or without leading batch dimensions.
+
+ Args:
+ source: A quaternion [w, i, j, k].
+ target: A quaternion [w, i, j, k].
+
+ Returns:
+ Scalar representing the rotational distance from source to target.
+ """
+ quat_product = quat_mul(source, quat_inv(target))
+ quat_product /= np.linalg.norm(quat_product, axis=-1, keepdims=True)
+ return np.linalg.norm(quat_log(quat_product), axis=-1, keepdims=True)
+
+
+def quat_rotate(quat, vec):
+ """Rotate a vector by a quaternion.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+ vec: A 3-vector representing a position.
+
+ Returns:
+ The rotated vector.
+ """
+ qvec = np.hstack([[0], vec])
+ return quat_mul(quat_mul(quat, qvec), quat_conj(quat))[1:]
+
+
+def quat_to_axisangle(quat):
+ """Returns the axis-angle corresponding to the provided quaternion.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+
+ Returns:
+ axisangle: A 3x1 numpy array describing the axis of rotation, with angle
+ encoded by its length.
+ """
+ angle = 2 * np.arccos(_clip_within_precision(quat[0], -1., 1.))
+
+ if angle < _TOL:
+ return np.zeros(3)
+ else:
+ qn = np.sin(angle/2)
+ angle = (angle + np.pi) % (2 * np.pi) - np.pi
+ axis = quat[1:4] / qn
+ return axis * angle
+
+
+def quat_to_euler(quat, ordering='XYZ'):
+ """Returns the euler angles corresponding to the provided quaternion.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+ ordering: (str) Desired euler angle ordering.
+
+ Returns:
+ euler_vec: The euler angle rotations.
+ """
+ mat = quat_to_mat(quat)
+ return rmat_to_euler(mat[0:3, 0:3], ordering=ordering)
+
+
+def quat_to_mat(quat):
+ """Return homogeneous rotation matrix from quaternion.
+
+ Args:
+ quat: A quaternion [w, i, j, k].
+
+ Returns:
+ A 4x4 homogeneous matrix with the rotation corresponding to `quat`.
+ """
+ q = np.array(quat, dtype=np.float64, copy=True)
+ nq = np.dot(q, q)
+ if nq < _TOL:
+ return np.identity(4)
+ q *= np.sqrt(2.0 / nq)
+ q = np.outer(q, q)
+ return np.array(
+ ((1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0),
+ (q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0),
+ (q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0),
+ (0.0, 0.0, 0.0, 1.0)),
+ dtype=np.float64)
+
+
+def rotation_x_axis(theta, full=False):
+  """Returns a rotation matrix of a rotation about the X-axis.
+
+  Supports vector-valued theta, in which case the returned array is of shape
+  (len(theta), 3, 3), or (len(theta), 4, 4) if full=True. If theta is scalar
+  the batch dimension is squeezed out.
+
+  Args:
+    theta: The rotation angle, in radians.
+    full: (bool) If true, returns a full 4x4 transform.
+  """
+  n, ct, st, rmat = _axis_rotation(theta, full)
+
+  rmat[:, 0, 0:3] = np.array([[1, 0, 0]])
+  rmat[:, 1, 0:3] = np.vstack([np.zeros(n), ct, -st]).T
+  rmat[:, 2, 0:3] = np.vstack([np.zeros(n), st, ct]).T
+
+  return rmat.squeeze()
+
+
+def rotation_y_axis(theta, full=False):
+  """Returns a rotation matrix of a rotation about the Y-axis.
+
+  Supports vector-valued theta, in which case the returned array is of shape
+  (len(theta), 3, 3), or (len(theta), 4, 4) if full=True. If theta is scalar
+  the batch dimension is squeezed out.
+
+  Args:
+    theta: The rotation angle, in radians.
+    full: (bool) If true, returns a full 4x4 transform.
+  """
+  n, ct, st, rmat = _axis_rotation(theta, full)
+
+  rmat[:, 0, 0:3] = np.vstack([ct, np.zeros(n), st]).T
+  rmat[:, 1, 0:3] = np.array([[0, 1, 0]])
+  rmat[:, 2, 0:3] = np.vstack([-st, np.zeros(n), ct]).T
+
+  return rmat.squeeze()
+
+
+def rotation_z_axis(theta, full=False):
+  """Returns a rotation matrix of a rotation about the Z-axis.
+
+  Supports vector-valued theta, in which case the returned array is of shape
+  (len(theta), 3, 3), or (len(theta), 4, 4) if full=True. If theta is scalar
+  the batch dimension is squeezed out.
+
+  Args:
+    theta: The rotation angle, in radians.
+    full: (bool) If true, returns a full 4x4 transform.
+  """
+  n, ct, st, rmat = _axis_rotation(theta, full)
+
+  rmat[:, 0, 0:3] = np.vstack([ct, -st, np.zeros(n)]).T
+  rmat[:, 1, 0:3] = np.vstack([st, ct, np.zeros(n)]).T
+  rmat[:, 2, 0:3] = np.array([[0, 0, 1]])
+
+  return rmat.squeeze()
+
+
+def rmat_to_euler(rmat, ordering='ZXZ'):
+ """Returns the euler angles corresponding to the provided rotation matrix.
+
+ Args:
+ rmat: The rotation matrix.
+ ordering: (str) Desired euler angle ordering.
+
+ Returns:
+ Euler angles corresponding to the provided rotation matrix.
+ """
+ return _eulermap[ordering](rmat)
+
+
+def mat_to_quat(mat):
+ """Return quaternion from homogeneous or rotation matrix.
+
+ Args:
+ mat: A homogeneous transform or rotation matrix
+
+ Returns:
+ A quaternion [w, i, j, k].
+ """
+ if mat.shape == (3, 3):
+ tmp = np.eye(4)
+ tmp[0:3, 0:3] = mat
+ mat = tmp
+
+ q = np.empty((4,), dtype=np.float64)
+ t = np.trace(mat)
+ if t > mat[3, 3]:
+ q[0] = t
+ q[3] = mat[1, 0] - mat[0, 1]
+ q[2] = mat[0, 2] - mat[2, 0]
+ q[1] = mat[2, 1] - mat[1, 2]
+ else:
+ i, j, k = 0, 1, 2
+ if mat[1, 1] > mat[0, 0]:
+ i, j, k = 1, 2, 0
+ if mat[2, 2] > mat[i, i]:
+ i, j, k = 2, 0, 1
+ t = mat[i, i] - (mat[j, j] + mat[k, k]) + mat[3, 3]
+ q[i + 1] = t
+ q[j + 1] = mat[i, j] + mat[j, i]
+ q[k + 1] = mat[k, i] + mat[i, k]
+ q[0] = mat[k, j] - mat[j, k]
+ q *= 0.5 / np.sqrt(t * mat[3, 3])
+ return q
+
+
+def axisangle_to_quat(axisangle, tol=0.0):
+ """Returns the quaternion corresponding to the provided axis-angle.
+
+ Args:
+ axisangle: A 3x1 numpy array describing the axis of rotation, with angle
+ encoded by its length.
+ tol: Tolerance for the angle magnitude below which the identity quaternion
+ is returned.
+
+ Returns:
+ A quaternion [w, i, j, k].
+ """
+ axisangle = np.asarray(axisangle)
+ angle = np.linalg.norm(axisangle, axis=-1, keepdims=True)
+ axis = np.where(angle <= tol, [1.0, 0.0, 0.0], axisangle / angle)
+ angle = np.where(angle <= tol, [0.0], angle)
+ sine, cosine = np.sin(angle / 2), np.cos(angle / 2)
+ return np.concatenate([cosine, axis * sine], axis=-1)
+
+
+# ################
+# # 2D Functions #
+# ################
+
+
+def rotation_matrix_2d(theta):
+ ct = np.cos(theta)
+ st = np.sin(theta)
+ return np.array([
+ [ct, -st],
+ [st, ct]
+ ])
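As a dependency-free illustration of the `[w, i, j, k]` conventions used by `quat_mul`, `quat_conj`, `quat_rotate` and `axisangle_to_quat` in the file above, the sketch below re-derives the same operations for a single quaternion using only the standard library. The helper names (`hamilton_product`, `conjugate`, `rotate`, `axisangle_to_quaternion`) are invented for this example and are not part of `dm_control`; the library versions operate on NumPy arrays and support batching, which this sketch does not attempt.

```python
import math


def hamilton_product(q1, q2):
  """Hamilton product of two [w, i, j, k] quaternions given as sequences."""
  w1, i1, j1, k1 = q1
  w2, i2, j2, k2 = q2
  return (
      w1 * w2 - i1 * i2 - j1 * j2 - k1 * k2,
      w1 * i2 + i1 * w2 + j1 * k2 - k1 * j2,
      w1 * j2 - i1 * k2 + j1 * w2 + k1 * i2,
      w1 * k2 + i1 * j2 - j1 * i2 + k1 * w2,
  )


def conjugate(q):
  """Negates the vector part; for a unit quaternion this is its inverse."""
  w, i, j, k = q
  return (w, -i, -j, -k)


def rotate(q, vec):
  """Rotates a 3-vector by a unit quaternion via q * (0, vec) * conj(q)."""
  qvec = (0.0,) + tuple(vec)
  return hamilton_product(hamilton_product(q, qvec), conjugate(q))[1:]


def axisangle_to_quaternion(axisangle):
  """Converts an axis-angle vector whose length encodes the angle."""
  angle = math.sqrt(sum(a * a for a in axisangle))
  if angle == 0.0:
    return (1.0, 0.0, 0.0, 0.0)  # Identity rotation.
  scale = math.sin(angle / 2) / angle
  return (math.cos(angle / 2),) + tuple(a * scale for a in axisangle)


# A quarter turn about Z should map the X axis onto the Y axis
# (up to floating-point rounding).
quarter_turn_z = axisangle_to_quaternion((0.0, 0.0, math.pi / 2))
rotated = rotate(quarter_turn_z, (1.0, 0.0, 0.0))
print([round(c, 6) for c in rotated])
```

The four rows of `hamilton_product` correspond one-to-one to the rows of the matrix shown in the `quat_mul` comment, which makes the sketch a convenient way to check the sign pattern by hand.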
diff --git a/dm_control/utils/transformations_test.py b/dm_control/utils/transformations_test.py
new file mode 100644
index 00000000..b11a5694
--- /dev/null
+++ b/dm_control/utils/transformations_test.py
@@ -0,0 +1,293 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import itertools
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.utils import transformations
+
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+_NUM_RANDOM_SAMPLES = 1000
+
+
+class TransformationsTest(parameterized.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._random_state = np.random.RandomState()
+
+ @parameterized.parameters(
+ {
+ 'quat': [-0.41473841, 0.59483601, -0.45089078, 0.52044181],
+ 'truemat':
+ np.array([[0.05167565, -0.10471773, 0.99315851],
+ [-0.96810656, -0.24937912, 0.02407785],
+ [0.24515162, -0.96272751, -0.11426475]])
+ },
+ {
+ 'quat': [0.08769298, 0.69897558, 0.02516888, 0.7093022],
+ 'truemat':
+ np.array([[-0.00748615, -0.08921678, 0.9959841],
+ [0.15958651, -0.98335294, -0.08688582],
+ [0.98715556, 0.15829519, 0.02159933]])
+ },
+ {
+ 'quat': [0.58847272, 0.44682507, 0.51443343, -0.43520737],
+ 'truemat':
+ np.array([[0.09190557, 0.97193884, 0.21653695],
+ [-0.05249182, 0.22188379, -0.97365918],
+ [-0.99438321, 0.07811829, 0.07141119]])
+ },
+ )
+ def test_quat_to_mat(self, quat, truemat):
+ """Tests hard-coded quat-mat pairs generated from mujoco if mj not avail."""
+ mat = transformations.quat_to_mat(quat)
+ np.testing.assert_allclose(mat[0:3, 0:3], truemat, atol=1e-7)
+
+ def test_quat_to_mat_mujoco_special(self):
+ # Test for special values that often cause numerical issues.
+ rng = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
+ for euler_tup in itertools.product(rng, rng, rng):
+ euler_vec = np.array(euler_tup, dtype=float)
+ mat = transformations.euler_to_rmat(euler_vec, ordering='XYZ')
+ quat = transformations.mat_to_quat(mat)
+ tr_mat = transformations.quat_to_mat(quat)
+ mj_mat = np.zeros(9)
+ mjlib.mju_quat2Mat(mj_mat, quat)
+ mj_mat = mj_mat.reshape(3, 3)
+ np.testing.assert_allclose(tr_mat[0:3, 0:3], mj_mat, atol=1e-10)
+ np.testing.assert_allclose(tr_mat[0:3, 0:3], mat, atol=1e-10)
+
+ def test_quat_to_mat_mujoco_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ quat = self._random_quaternion()
+ tr_mat = transformations.quat_to_mat(quat)
+ mj_mat = np.zeros(9)
+ mjlib.mju_quat2Mat(mj_mat, quat)
+ mj_mat = mj_mat.reshape(3, 3)
+ np.testing.assert_allclose(tr_mat[0:3, 0:3], mj_mat)
+
+ def test_mat_to_quat_mujoco(self):
+ subsamps = 10
+ rng = np.linspace(-np.pi, np.pi, subsamps)
+ for euler_tup in itertools.product(rng, rng, rng):
+ euler_vec = np.array(euler_tup, dtype=float)
+ mat = transformations.euler_to_rmat(euler_vec, ordering='XYZ')
+ mj_quat = np.empty(4, dtype=euler_vec.dtype)
+ mjlib.mju_mat2Quat(mj_quat, mat.flatten())
+ tr_quat = transformations.mat_to_quat(mat)
+ self.assertTrue(
+ np.allclose(mj_quat, tr_quat) or np.allclose(mj_quat, -tr_quat))
+
+ @parameterized.parameters(
+ {'angles': (0, 0, 0)},
+ {'angles': (-0.1, 0.4, -1.3)}
+ )
+ def test_euler_to_rmat_special(self, angles):
+ # Test for special values that often cause numerical issues.
+ r1, r2, r3 = angles
+ for ordering in transformations._eulermap.keys():
+ r = transformations.euler_to_rmat(np.array([r1, r2, r3]), ordering)
+ euler_angles = transformations.rmat_to_euler(r, ordering)
+ np.testing.assert_allclose(euler_angles, [r1, r2, r3])
+
+ def test_quat_mul_vs_mat_mul_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ quat1 = self._random_quaternion()
+ quat2 = self._random_quaternion()
+ rmat1 = transformations.quat_to_mat(quat1)[0:3, 0:3]
+ rmat2 = transformations.quat_to_mat(quat2)[0:3, 0:3]
+ quat_prod = transformations.quat_mul(quat1, quat2)
+ rmat_prod_q = transformations.quat_to_mat(quat_prod)[0:3, 0:3]
+ rmat_prod = rmat1.dot(rmat2)
+ np.testing.assert_allclose(rmat_prod, rmat_prod_q)
+
+ def test_quat_mul_vs_mat_mul_random_batched(self):
+ quat1 = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ quat2 = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ quat_prod = transformations.quat_mul(quat1, quat2)
+ for k in range(_NUM_RANDOM_SAMPLES):
+ rmat1 = transformations.quat_to_mat(quat1[k])[0:3, 0:3]
+ rmat2 = transformations.quat_to_mat(quat2[k])[0:3, 0:3]
+ rmat_prod_q = transformations.quat_to_mat(quat_prod[k])[0:3, 0:3]
+ rmat_prod = rmat1.dot(rmat2)
+ np.testing.assert_allclose(rmat_prod, rmat_prod_q)
+
+ def test_quat_mul_mujoco_special(self):
+ # Test for special values that often cause numerical issues.
+ rng = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
+ quat1 = np.array([1, 0, 0, 0], dtype=np.float64)
+ for euler_tup in itertools.product(rng, rng, rng):
+ euler_vec = np.array(euler_tup, dtype=np.float64)
+ quat2 = transformations.euler_to_quat(euler_vec, ordering='XYZ')
+ quat_prod_tr = transformations.quat_mul(quat1, quat2)
+ quat_prod_mj = np.zeros(4)
+ mjlib.mju_mulQuat(quat_prod_mj, quat1, quat2)
+ np.testing.assert_allclose(quat_prod_tr, quat_prod_mj, atol=1e-14)
+ quat1 = quat2
+
+ def test_quat_mul_mujoco_special_batched(self):
+ # Test for special values that often cause numerical issues.
+ rng = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
+ q1, q2, qmj = [], [], []
+ quat1 = np.array([1, 0, 0, 0], dtype=np.float64)
+ for euler_tup in itertools.product(rng, rng, rng):
+ euler_vec = np.array(euler_tup, dtype=np.float64)
+ quat2 = transformations.euler_to_quat(euler_vec, ordering='XYZ')
+ quat_prod_mj = np.zeros(4)
+ mjlib.mju_mulQuat(quat_prod_mj, quat1, quat2)
+ q1.append(quat1)
+ q2.append(quat2)
+ qmj.append(quat_prod_mj)
+ quat1 = quat2
+ q1 = np.stack(q1, axis=0)
+ q2 = np.stack(q2, axis=0)
+ qmj = np.stack(qmj, axis=0)
+ qtr = transformations.quat_mul(q1, q2)
+ np.testing.assert_allclose(qtr, qmj, atol=1e-14)
+
+ def test_quat_mul_mujoco_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ quat1 = self._random_quaternion()
+ quat2 = self._random_quaternion()
+ quat_prod_tr = transformations.quat_mul(quat1, quat2)
+ quat_prod_mj = np.zeros(4)
+ mjlib.mju_mulQuat(quat_prod_mj, quat1, quat2)
+ np.testing.assert_allclose(quat_prod_tr, quat_prod_mj)
+
+ def test_quat_mul_mujoco_random_batched(self):
+ quat1 = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ quat2 = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ quat_prod_tr = transformations.quat_mul(quat1, quat2)
+ for k in range(quat1.shape[0]):
+ quat_prod_mj = np.zeros(4)
+ mjlib.mju_mulQuat(quat_prod_mj, quat1[k], quat2[k])
+ np.testing.assert_allclose(quat_prod_tr[k], quat_prod_mj)
+
+ def test_quat_rotate_mujoco_special(self):
+ # Test for special values that often cause numerical issues.
+ rng = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
+ vec = np.array([1, 0, 0], dtype=np.float64)
+ for euler_tup in itertools.product(rng, rng, rng):
+ euler_vec = np.array(euler_tup, dtype=np.float64)
+ quat = transformations.euler_to_quat(euler_vec, ordering='XYZ')
+ rotated_vec_tr = transformations.quat_rotate(quat, vec)
+ rotated_vec_mj = np.zeros(3)
+ mjlib.mju_rotVecQuat(rotated_vec_mj, vec, quat)
+ np.testing.assert_allclose(rotated_vec_tr, rotated_vec_mj, atol=1e-14)
+
+ def test_quat_rotate_mujoco_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ quat = self._random_quaternion()
+ vec = self._random_state.rand(3)
+ rotated_vec_tr = transformations.quat_rotate(quat, vec)
+ rotated_vec_mj = np.zeros(3)
+ mjlib.mju_rotVecQuat(rotated_vec_mj, vec, quat)
+ np.testing.assert_allclose(rotated_vec_tr, rotated_vec_mj)
+
+ def test_quat_diff_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ source = self._random_quaternion()
+ target = self._random_quaternion()
+ np.testing.assert_allclose(
+ transformations.quat_diff(source, target),
+ transformations.quat_mul(transformations.quat_conj(source), target))
+
+ def test_quat_diff_random_batched(self):
+ source = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ target = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ np.testing.assert_allclose(
+ transformations.quat_diff(source, target),
+ transformations.quat_mul(transformations.quat_conj(source), target))
+
+ def test_quat_dist_random(self):
+ for _ in range(_NUM_RANDOM_SAMPLES):
+ # test with normalized quaternions for stability of test
+ source = self._random_quaternion()
+ target = self._random_quaternion()
+ source /= np.linalg.norm(source)
+ target /= np.linalg.norm(target)
+ self.assertGreater(transformations.quat_dist(source, target), 0)
+ np.testing.assert_allclose(
+ transformations.quat_dist(source, source), 0, atol=1e-9)
+
+ def test_quat_dist_random_batched(self):
+ # Test batched quat dist
+ source_quats = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ target_quats = np.stack(
+ [self._random_quaternion() for _ in range(_NUM_RANDOM_SAMPLES)], axis=0)
+ source_quats /= np.linalg.norm(source_quats, axis=-1, keepdims=True)
+ target_quats /= np.linalg.norm(target_quats, axis=-1, keepdims=True)
+ np.testing.assert_allclose(
+ transformations.quat_dist(source_quats, source_quats), 0, atol=1e-9)
+ np.testing.assert_equal(
+ transformations.quat_dist(source_quats, target_quats) > 0, 1)
+
+ def _random_quaternion(self):
+ rand = self._random_state.rand(3)
+ r1 = np.sqrt(1.0 - rand[0])
+ r2 = np.sqrt(rand[0])
+ pi2 = np.pi * 2.0
+ t1 = pi2 * rand[1]
+ t2 = pi2 * rand[2]
+ return np.array(
+ (np.cos(t2) * r2, np.sin(t1) * r1, np.cos(t1) * r1, np.sin(t2) * r2),
+ dtype=np.float64)
+
+ def test_axisangle_to_quat(self):
+ axisangle = np.array([0.1, 0.2, 0.3])
+ quat = transformations.axisangle_to_quat(axisangle)
+ np.testing.assert_allclose(
+ quat, [0.982551, 0.0497088, 0.0994177, 0.1491265], atol=1e-6
+ )
+
+ def test_axisangle_to_quat_zero(self):
+ axisangle = np.array([0, 0, 0])
+ quat = transformations.axisangle_to_quat(axisangle)
+ np.testing.assert_allclose(quat, [1, 0, 0, 0])
+
+ def test_axisangle_to_quat_zero_tol(self):
+ axisangle = np.array([0, 0, 1e-2])
+ quat = transformations.axisangle_to_quat(axisangle, tol=1e-1)
+ np.testing.assert_allclose(quat, [1, 0, 0, 0])
+
+ def test_axisangle_to_quat_batched(self):
+ axisangle = np.stack([np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])])
+ quat = transformations.axisangle_to_quat(axisangle)
+ np.testing.assert_allclose(
+ quat,
+ [
+ [0.982551, 0.0497088, 0.0994177, 0.1491265],
+ [0.9052841, 0.1936448, 0.242056, 0.2904672],
+ ],
+ atol=1e-6,
+ )
+
+
+if __name__ == '__main__':
+ absltest.main()
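The `_random_quaternion` helper in the test file above builds a quaternion from three uniform samples. Because the two radii satisfy r1² + r2² = 1, the result has unit norm by construction. A minimal stand-alone sketch of the same construction follows, using the standard library `random` module in place of `np.random.RandomState`; the function name `random_unit_quaternion` and the fixed seed are choices made for this example only.

```python
import math
import random


def random_unit_quaternion(rng):
  """Draws a [w, i, j, k] quaternion of unit norm from three uniform samples."""
  u0, u1, u2 = rng.random(), rng.random(), rng.random()
  r1 = math.sqrt(1.0 - u0)
  r2 = math.sqrt(u0)
  t1 = 2.0 * math.pi * u1
  t2 = 2.0 * math.pi * u2
  # Squared norm: r2^2 (cos^2 t2 + sin^2 t2) + r1^2 (sin^2 t1 + cos^2 t1) = 1.
  return (math.cos(t2) * r2, math.sin(t1) * r1,
          math.cos(t1) * r1, math.sin(t2) * r2)


rng = random.Random(0)  # Seeded so that a failing sample can be reproduced.
samples = [random_unit_quaternion(rng) for _ in range(1000)]
max_norm_error = max(
    abs(math.sqrt(sum(c * c for c in q)) - 1.0) for q in samples)
print(max_norm_error < 1e-12)
```

Seeding the generator is a design choice of the sketch: the test class above constructs an unseeded `RandomState`, so its samples differ from run to run.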
diff --git a/dm_control/utils/xml_tools.py b/dm_control/utils/xml_tools.py
index d0e04355..661cc684 100644
--- a/dm_control/utils/xml_tools.py
+++ b/dm_control/utils/xml_tools.py
@@ -15,14 +15,8 @@
"""Helper functions for model xml creation and modification."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import copy
-# Internal dependencies.
-
from lxml import etree
diff --git a/dm_control/utils/xml_tools_test.py b/dm_control/utils/xml_tools_test.py
index 28a03712..54ba0dfd 100644
--- a/dm_control/utils/xml_tools_test.py
+++ b/dm_control/utils/xml_tools_test.py
@@ -15,19 +15,13 @@
"""Tests for utils.xml_tools."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Internal dependencies.
+import io
from absl.testing import absltest
-
from dm_control.utils import xml_tools
+import lxml
-from lxml import etree
-
-import six
+etree = lxml.etree
class XmlHelperTest(absltest.TestCase):
@@ -47,7 +41,7 @@ def test_tostring(self):
"""
- tree = xml_tools.parse(six.StringIO(xml_str))
+ tree = xml_tools.parse(io.StringIO(xml_str))
self.assertEqual(b'\n \n \n \n\n',
etree.tostring(tree, pretty_print=True))
@@ -61,7 +55,7 @@ def test_find_element(self):
"""
- tree = xml_tools.parse(six.StringIO(xml_str))
+ tree = xml_tools.parse(io.StringIO(xml_str))
world = xml_tools.find_element(root=tree, tag='world', name='world_name')
self.assertEqual(world.tag, 'world')
self.assertEqual(world.attrib['name'], 'world_name')
@@ -70,10 +64,10 @@ def test_find_element(self):
self.assertEqual(geom.tag, 'geom')
self.assertEqual(geom.attrib['name'], 'geom_name')
- with self.assertRaisesRegexp(ValueError, 'Element with tag'):
+ with self.assertRaisesRegex(ValueError, 'Element with tag'):
xml_tools.find_element(root=tree, tag='does_not_exist', name='name')
- with self.assertRaisesRegexp(ValueError, 'Element with tag'):
+ with self.assertRaisesRegex(ValueError, 'Element with tag'):
xml_tools.find_element(root=tree, tag='world', name='does_not_exist')
diff --git a/dm_control/viewer/README.md b/dm_control/viewer/README.md
new file mode 100644
index 00000000..03227288
--- /dev/null
+++ b/dm_control/viewer/README.md
@@ -0,0 +1,67 @@
+# Interactive environment viewer
+
+# 
+
+The `dm_control.viewer` library can be used to visualize and interact with a
+control environment. The following example shows how to launch the viewer with
+an environment from the Control Suite:
+
+```python
+from dm_control import suite
+from dm_control import viewer
+
+# Load an environment from the Control Suite.
+env = suite.load(domain_name="humanoid", task_name="stand")
+
+# Launch the viewer application.
+viewer.launch(env)
+```
+
+For convenience we also provide a viewer launch script for the Control Suite in
+`dm_control/suite/explore.py`.
+
+## Viewing the environment with a policy in the loop
+
+The viewer is also capable of running the environment with a policy in the loop
+to provide actions. This is done by passing the optional `policy` argument to
+`viewer.launch`. The `policy` should be a callable that accepts a `TimeStep` and
+returns a numpy array of actions conforming to `environment.action_spec()`. The
+example below shows how to execute a random uniform policy using the viewer:
+
+```python
+from dm_control import suite
+from dm_control import viewer
+import numpy as np
+
+env = suite.load(domain_name="humanoid", task_name="stand")
+action_spec = env.action_spec()
+
+# Define a uniform random policy.
+def random_policy(time_step):
+  del time_step  # Unused.
+  return np.random.uniform(low=action_spec.minimum,
+                           high=action_spec.maximum,
+                           size=action_spec.shape)
+
+# Launch the viewer application.
+viewer.launch(env, policy=random_policy)
+```
+
+## Keyboard and mouse controls
+
+The viewer contains a built-in help screen that can be brought up by pressing
+`F1`. You will find a comprehensive description of keyboard and mouse controls
+there.
+
+## Status view
+
+The status view displays the current state of the simulation:
+
+- Status - Current state of the runtime state machine.
+- Time - Simulation clock, followed by the current time multiplier
+  setting.
+- CPU - Time per frame consumed by the physics simulation.
+- FPS - Number of frames per second rendered by the application.
+- Camera - Name of the active camera.
+- Paused - Whether the simulation is paused.
+- Error - Most recently caught error message.
diff --git a/dm_control/viewer/__init__.py b/dm_control/viewer/__init__.py
new file mode 100644
index 00000000..6409fbac
--- /dev/null
+++ b/dm_control/viewer/__init__.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Suite environments viewer package."""
+
+
+from dm_control.viewer import application
+
+
+def launch(environment_loader, policy=None, title='Explorer', width=1024,
+           height=768):
+  """Launches an environment viewer.
+
+  Args:
+    environment_loader: An environment loader (a callable that returns an
+      instance of dm_control.rl.control.Environment), or an instance of
+      dm_control.rl.control.Environment.
+    policy: An optional callable corresponding to a policy to execute within the
+      environment. It should accept a `TimeStep` and return a numpy array of
+      actions conforming to the output of `environment.action_spec()`.
+    title: Application title to be displayed in the title bar.
+    width: Window width, in pixels.
+    height: Window height, in pixels.
+  Raises:
+    ValueError: When 'environment_loader' argument is set to None.
+  """
+  app = application.Application(title=title, width=width, height=height)
+  app.launch(environment_loader=environment_loader, policy=policy)
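According to its docstring, `launch` accepts either an environment instance or a callable that builds one, and raises `ValueError` when given `None`. One way such an argument could be normalized to a single "loader" form is sketched below. `as_loader` and `FakeEnvironment` are hypothetical names used only for this illustration, and the actual handling inside `application.Application.launch` may differ.

```python
def as_loader(environment_loader):
  """Returns a zero-argument callable that yields an environment."""
  if environment_loader is None:
    raise ValueError('"environment_loader" argument is required.')
  if callable(environment_loader):
    return environment_loader
  # Wrap an already constructed environment so both cases share one code path.
  return lambda: environment_loader


class FakeEnvironment:
  """Stand-in for an environment instance; not a dm_control class."""


env = FakeEnvironment()
loader_from_instance = as_loader(env)
loader_from_callable = as_loader(FakeEnvironment)
print(loader_from_instance() is env)
print(isinstance(loader_from_callable(), FakeEnvironment))
```

A caveat of this approach is that an environment object which itself defines `__call__` would be treated as a loader, so the sketch assumes environment instances are not callable.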
diff --git a/dm_control/viewer/application.py b/dm_control/viewer/application.py
new file mode 100644
index 00000000..81ab1ae0
--- /dev/null
+++ b/dm_control/viewer/application.py
@@ -0,0 +1,333 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Viewer application module."""
+
+import collections
+
+from dm_control import _render
+from dm_control.viewer import gui
+from dm_control.viewer import renderer
+from dm_control.viewer import runtime
+from dm_control.viewer import user_input
+from dm_control.viewer import util
+from dm_control.viewer import viewer
+from dm_control.viewer import views
+
+_DOUBLE_BUFFERING = user_input.KEY_F5
+_PAUSE = user_input.KEY_SPACE
+_RESTART = user_input.KEY_BACKSPACE
+_ADVANCE_SIMULATION = user_input.KEY_RIGHT
+_SPEED_UP_TIME = user_input.KEY_EQUAL
+_SLOW_DOWN_TIME = user_input.KEY_MINUS
+_HELP = user_input.KEY_F1
+_STATUS = user_input.KEY_F2
+
+_MAX_FRONTBUFFER_SIZE = 2048
+_MISSING_STATUS_ENTRY = '--'
+_RUNTIME_STOPPED_LABEL = 'EPISODE TERMINATED - hit backspace to restart'
+_STATUS_LABEL = 'Status'
+_TIME_LABEL = 'Time'
+_CPU_LABEL = 'CPU'
+_FPS_LABEL = 'FPS'
+_CAMERA_LABEL = 'Camera'
+_PAUSED_LABEL = 'Paused'
+_ERROR_LABEL = 'Error'
+
+
+class Help(views.ColumnTextModel):
+ """Contains the description of input map employed in the application."""
+
+ def __init__(self):
+ """Instance initializer."""
+ self._value = [
+ ['Help', 'F1'],
+ ['Info', 'F2'],
+ ['Stereo', 'F5'],
+ ['Frame', 'F6'],
+ ['Label', 'F7'],
+ ['--------------', ''],
+ ['Pause', 'Space'],
+ ['Reset', 'BackSpace'],
+ ['Autoscale', 'Ctrl A'],
+ ['Geoms', '0 - 4'],
+ ['Sites', 'Shift 0 - 4'],
+ ['Speed Up', '='],
+ ['Slow Down', '-'],
+ ['Switch Cam', '[ ]'],
+ ['--------------', ''],
+ ['Translate', 'R drag'],
+ ['Rotate', 'L drag'],
+ ['Zoom', 'Scroll'],
+ ['Select', 'L dblclick'],
+ ['Center', 'R dblclick'],
+ ['Track', 'Ctrl R dblclick / Esc'],
+ ['Perturb', 'Ctrl [Shift] L/R drag'],
+ ]
+
+ def get_columns(self):
+ """Returns the text to display in two columns."""
+ return self._value
+
+
+class Status(views.ColumnTextModel):
+ """Monitors and returns the status of the application."""
+
+ def __init__(self, time_multiplier, pause, frame_timer):
+ """Instance initializer.
+
+ Args:
+ time_multiplier: Instance of util.TimeMultiplier.
+ pause: An observable pause subject, instance of util.ObservableFlag.
+ frame_timer: A Timer instance counting duration of frames.
+ """
+ self._runtime = None
+ self._time_multiplier = time_multiplier
+ self._camera = None
+ self._pause = pause
+ self._frame_timer = frame_timer
+ self._fps_counter = util.Integrator()
+ self._cpu_counter = util.Integrator()
+
+ self._value = collections.OrderedDict([
+ (_STATUS_LABEL, _MISSING_STATUS_ENTRY),
+ (_TIME_LABEL, _MISSING_STATUS_ENTRY),
+ (_CPU_LABEL, _MISSING_STATUS_ENTRY),
+ (_FPS_LABEL, _MISSING_STATUS_ENTRY),
+ (_CAMERA_LABEL, _MISSING_STATUS_ENTRY),
+ (_PAUSED_LABEL, _MISSING_STATUS_ENTRY),
+ (_ERROR_LABEL, _MISSING_STATUS_ENTRY),
+ ])
+
+ def set_camera(self, camera):
+ """Updates the active camera instance.
+
+ Args:
+ camera: Instance of renderer.SceneCamera.
+ """
+ self._camera = camera
+
+ def set_runtime(self, instance):
+ """Updates the active runtime instance.
+
+ Args:
+ instance: Instance of runtime.Base.
+ """
+ if self._runtime:
+ self._runtime.on_error -= self._on_error
+ self._runtime.on_episode_begin -= self._clear_error
+ self._runtime = instance
+ if self._runtime:
+ self._runtime.on_error += self._on_error
+ self._runtime.on_episode_begin += self._clear_error
+
+ def get_columns(self):
+ """Returns the text to display in two columns."""
+ if self._frame_timer.measured_time > 0:
+ self._fps_counter.value = 1. / self._frame_timer.measured_time
+ self._value[_FPS_LABEL] = '{0:.1f}'.format(self._fps_counter.value)
+
+ if self._runtime:
+ if self._runtime.state == runtime.State.STOPPED:
+ self._value[_STATUS_LABEL] = _RUNTIME_STOPPED_LABEL
+ else:
+ self._value[_STATUS_LABEL] = str(self._runtime.state)
+
+ self._cpu_counter.value = self._runtime.simulation_time
+
+ self._value[_TIME_LABEL] = '{0:.1f} ({1}x)'.format(
+ self._runtime.get_time(), str(self._time_multiplier))
+ self._value[_CPU_LABEL] = '{0:.2f}ms'.format(
+ self._cpu_counter.value * 1000.0)
+ else:
+ self._value[_STATUS_LABEL] = _MISSING_STATUS_ENTRY
+ self._value[_TIME_LABEL] = _MISSING_STATUS_ENTRY
+ self._value[_CPU_LABEL] = _MISSING_STATUS_ENTRY
+
+ if self._camera:
+ self._value[_CAMERA_LABEL] = self._camera.name
+ else:
+ self._value[_CAMERA_LABEL] = _MISSING_STATUS_ENTRY
+
+ self._value[_PAUSED_LABEL] = str(self._pause.value)
+
+ return list(self._value.items()) # For Python 2/3 compatibility.
+
+ def _clear_error(self):
+ self._value[_ERROR_LABEL] = _MISSING_STATUS_ENTRY
+
+ def _on_error(self, error_msg):
+ self._value[_ERROR_LABEL] = error_msg
+
+
+class ReloadParams(collections.namedtuple(
+ 'ReloadParams', ['zoom_to_scene'])):
+ """Parameters of a reload request."""
+
+
+class Application:
+ """Viewer application."""
+
+ def __init__(self, title='Explorer', width=1024, height=768):
+ """Instance initializer."""
+ self._render_surface = None
+ self._renderer = renderer.NullRenderer()
+ self._viewport = renderer.Viewport(width, height)
+ self._window = gui.RenderWindow(width, height, title)
+
+ self._pause_subject = util.ObservableFlag(True)
+ self._time_multiplier = util.TimeMultiplier(1.)
+ self._frame_timer = util.Timer()
+ self._viewer = viewer.Viewer(
+ self._viewport, self._window.mouse, self._window.keyboard)
+ self._viewer_layout = views.ViewportLayout()
+ self._status = Status(
+ self._time_multiplier, self._pause_subject, self._frame_timer)
+
+ self._runtime = None
+ self._environment_loader = None
+ self._environment = None
+ self._policy = None
+ self._deferred_reload_request = None
+
+ status_view_toggle = self._build_view_toggle(
+ views.ColumnTextView(self._status), views.PanelLocation.BOTTOM_LEFT)
+ help_view_toggle = self._build_view_toggle(
+ views.ColumnTextView(Help()), views.PanelLocation.TOP_RIGHT)
+ status_view_toggle()
+
+ self._input_map = user_input.InputMap(
+ self._window.mouse, self._window.keyboard)
+ self._input_map.bind(self._pause_subject.toggle, _PAUSE)
+ self._input_map.bind(self._time_multiplier.increase, _SPEED_UP_TIME)
+ self._input_map.bind(self._time_multiplier.decrease, _SLOW_DOWN_TIME)
+ self._input_map.bind(self._advance_simulation, _ADVANCE_SIMULATION)
+ self._input_map.bind(self._restart_runtime, _RESTART)
+ self._input_map.bind(help_view_toggle, _HELP)
+ self._input_map.bind(status_view_toggle, _STATUS)
+
+ def _on_reload(self, zoom_to_scene=False):
+ """Perform initialization related to Physics reload.
+
+ Reset the components that depend on a specific Physics class instance.
+
+ Args:
+ zoom_to_scene: Whether the camera should zoom to show the entire scene
+ after the reload is complete.
+ """
+ self._deferred_reload_request = ReloadParams(zoom_to_scene)
+ self._viewer.deinitialize()
+ self._status.set_camera(None)
+
+ def _perform_deferred_reload(self, params):
+ """Performs the deferred part of initialization related to Physics reload.
+
+ Args:
+ params: Deferred reload parameters, an instance of ReloadParams.
+ """
+ if self._render_surface:
+ self._render_surface.free()
+ if self._renderer:
+ self._renderer.release()
+ self._render_surface = _render.Renderer(
+ max_width=_MAX_FRONTBUFFER_SIZE, max_height=_MAX_FRONTBUFFER_SIZE)
+ self._renderer = renderer.OffScreenRenderer(
+ self._environment.physics.model, self._render_surface)
+ self._renderer.components += self._viewer_layout
+ self._viewer.initialize(
+ self._environment.physics, self._renderer, touchpad=False)
+ self._status.set_camera(self._viewer.camera)
+ if params.zoom_to_scene:
+ self._viewer.zoom_to_scene()
+
+ def _build_view_toggle(self, view, location):
+ def toggle():
+ if view in self._viewer_layout:
+ self._viewer_layout.remove(view)
+ else:
+ self._viewer_layout.add(view, location)
+ return toggle
+
+ def _tick(self):
+ """Performs any deferred reload, ticks the runtime and renders one frame."""
+ if self._deferred_reload_request:
+ self._perform_deferred_reload(self._deferred_reload_request)
+ self._deferred_reload_request = None
+ time_elapsed = self._frame_timer.tick() * self._time_multiplier.get()
+ if self._runtime:
+ with self._viewer.perturbation.apply(self._pause_subject.value):
+ self._runtime.tick(time_elapsed, self._pause_subject.value)
+ self._viewer.render()
+
+ def _load_environment(self, zoom_to_scene):
+ """Loads a new environment."""
+ if self._runtime:
+ del self._runtime
+ self._runtime = None
+ self._environment = None
+ environment_instance = None
+ if self._environment_loader:
+ environment_instance = self._environment_loader()
+ if environment_instance:
+ self._environment = environment_instance
+ self._runtime = runtime.Runtime(
+ environment=self._environment, policy=self._policy)
+ self._runtime.on_physics_changed += lambda: self._on_reload(False)
+ self._status.set_runtime(self._runtime)
+ self._on_reload(zoom_to_scene=zoom_to_scene)
+
+ def _restart_runtime(self):
+ """Restarts the episode, resetting environment, model, and data."""
+ if self._runtime:
+ self._runtime.stop()
+ self._load_environment(zoom_to_scene=False)
+
+ if self._policy:
+ if hasattr(self._policy, 'reset'):
+ self._policy.reset()
+
+ def _advance_simulation(self):
+ if self._runtime:
+ self._runtime.single_step()
+
+ def launch(self, environment_loader, policy=None):
+ """Starts the viewer with the specified policy and environment.
+
+ Args:
+ environment_loader: Either a callable that takes no arguments and returns
+ an instance of dm_control.rl.control.Environment, or an instance of
+ dm_control.rl.control.Environment.
+ policy: An optional callable corresponding to a policy to execute
+ within the environment. It should accept a `TimeStep` and return
+ a numpy array of actions conforming to the output of
+ `environment.action_spec()`. If the callable implements a method `reset`
+ then this method is called when the viewer is reset.
+
+ Raises:
+ ValueError: If `environment_loader` is None.
+ """
+ if environment_loader is None:
+ raise ValueError('"environment_loader" argument is required.')
+ if callable(environment_loader):
+ self._environment_loader = environment_loader
+ else:
+ self._environment_loader = lambda: environment_loader
+ self._policy = policy
+ self._load_environment(zoom_to_scene=True)
+ def tick():
+ self._viewport.set_size(*self._window.shape)
+ self._tick()
+ return self._renderer.pixels
+ self._window.event_loop(tick_func=tick)
+ self._window.close()
diff --git a/dm_control/viewer/application_test.py b/dm_control/viewer/application_test.py
new file mode 100644
index 00000000..0ea1d5e8
--- /dev/null
+++ b/dm_control/viewer/application_test.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests of the application.py module."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.viewer import application
+import dm_env
+from dm_env import specs
+import mock
+import numpy as np
+
+
+class ApplicationTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ with mock.patch(application.__name__ + '.gui'):
+ self.app = application.Application()
+
+ self.app._viewer = mock.MagicMock()
+ self.app._keyboard_action = mock.MagicMock()
+
+ self.environment = mock.MagicMock(spec=dm_env.Environment)
+ self.environment.action_spec.return_value = specs.BoundedArray(
+ (1,), np.float64, -1, 1)
+ self.environment.physics = mock.MagicMock()
+ self.app._environment = self.environment
+ self.agent = mock.MagicMock()
+ self.loader = lambda: self.environment
+
+ def test_on_reload_defers_viewer_initialization_until_tick(self):
+ self.app._on_reload(zoom_to_scene=True)
+ self.assertEqual(0, self.app._viewer.initialize.call_count)
+ self.assertIsNotNone(self.app._deferred_reload_request)
+
+ def test_deferred_on_reload_parameters(self):
+ self.app._on_reload(zoom_to_scene=True)
+ self.assertTrue(self.app._deferred_reload_request.zoom_to_scene)
+ self.app._on_reload(zoom_to_scene=False)
+ self.assertFalse(self.app._deferred_reload_request.zoom_to_scene)
+
+ def test_executing_deferred_initialization(self):
+ self.app._deferred_reload_request = application.ReloadParams(False)
+ self.app._tick()
+ self.app._viewer.initialize.assert_called_once()
+
+ def test_processing_zoom_to_scene_request(self):
+ self.app._perform_deferred_reload(application.ReloadParams(True))
+ self.app._viewer.zoom_to_scene.assert_called_once()
+
+ def test_skipping_zoom_to_scene(self):
+ self.app._perform_deferred_reload(application.ReloadParams(False))
+ self.app._viewer.zoom_to_scene.assert_not_called()
+
+ def test_on_reload_deinitializes_viewer_instantly(self):
+ self.app._on_reload()
+ self.app._viewer.deinitialize.assert_called_once()
+
+ def test_zoom_to_scene_after_launch(self):
+ self.app.launch(self.loader, self.agent)
+ self.app._viewer.zoom_to_scene()
+
+ def test_tick_runtime(self):
+ self.app._runtime = mock.MagicMock()
+ self.app._pause_subject.value = False
+ self.app._tick()
+ self.app._runtime.tick.assert_called_once()
+
+ def test_restart_runtime(self):
+ self.app._load_environment = mock.MagicMock()
+ self.app._runtime = mock.MagicMock()
+ self.app._restart_runtime()
+ self.app._runtime.stop.assert_called_once()
+ self.app._load_environment.assert_called_once()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/gui/__init__.py b/dm_control/viewer/gui/__init__.py
new file mode 100644
index 00000000..2b8012a8
--- /dev/null
+++ b/dm_control/viewer/gui/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Viewer's windowing systems."""
+
+from dm_control import _render
+
+# pylint: disable=g-import-not-at-top
+# pylint: disable=invalid-name
+
+RenderWindow = None
+
+try:
+ from dm_control.viewer.gui import glfw_gui
+ RenderWindow = glfw_gui.GlfwWindow
+except ImportError:
+ pass
+
+if RenderWindow is None:
+
+ def ErrorRenderWindow(*args, **kwargs):
+ del args, kwargs
+ raise ImportError(
+ 'Cannot create a window because no windowing system could be imported')
+ RenderWindow = ErrorRenderWindow
+
+del _render
diff --git a/dm_control/viewer/gui/base.py b/dm_control/viewer/gui/base.py
new file mode 100644
index 00000000..932b1471
--- /dev/null
+++ b/dm_control/viewer/gui/base.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utilities and base classes used exclusively in the gui package."""
+
+import abc
+import threading
+import time
+
+from dm_control.viewer import user_input
+
+_DOUBLE_CLICK_INTERVAL = 0.25 # seconds
+
+
+class InputEventsProcessor(metaclass=abc.ABCMeta):
+ """Thread safe input events processor."""
+
+ def __init__(self):
+ """Instance initializer."""
+ self._lock = threading.RLock()
+ self._events = []
+
+ def add_event(self, receivers, *args):
+ """Adds a new event to the processing queue."""
+ if not all(callable(receiver) for receiver in receivers):
+ raise TypeError('Receivers are expected to be callables.')
+ def event():
+ for receiver in list(receivers):
+ receiver(*args)
+ with self._lock:
+ self._events.append(event)
+
+ def process_events(self):
+ """Invokes each of the events in the queue.
+
+ Thread safe for queue access but not during event invocations.
+
+ This method must be called regularly on the main thread.
+ """
+ with self._lock:
+ # Swap event buffers quickly so that we don't block the input thread for
+ # too long.
+ events_to_process = self._events
+ self._events = []
+
+ # Now that we made the swap, process the received events in our own time.
+ for event in events_to_process:
+ event()
+
+
+class DoubleClickDetector:
+ """Detects double click events."""
+
+ def __init__(self):
+ self._double_clicks = {}
+
+ def process(self, button, action):
+ """Attempts to identify a mouse button click as a double click event."""
+ if action != user_input.PRESS:
+ return False
+
+ curr_time = time.time()
+ timestamp = self._double_clicks.get(button, None)
+ if timestamp is None:
+ # No previous click registered.
+ self._double_clicks[button] = curr_time
+ return False
+ else:
+ time_elapsed = curr_time - timestamp
+ if time_elapsed < _DOUBLE_CLICK_INTERVAL:
+ # Double click discovered.
+ self._double_clicks[button] = None
+ return True
+ else:
+ # The previous click was too long ago, so discard it and start a fresh
+ # timer.
+ self._double_clicks[button] = curr_time
+ return False
diff --git a/dm_control/viewer/gui/base_test.py b/dm_control/viewer/gui/base_test.py
new file mode 100644
index 00000000..39f89db0
--- /dev/null
+++ b/dm_control/viewer/gui/base_test.py
@@ -0,0 +1,68 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the base windowing system."""
+
+
+from absl.testing import absltest
+from dm_control.viewer import user_input
+import mock
+
+
+# pylint: disable=g-import-not-at-top
+_OPEN_GL_MOCK = mock.MagicMock()
+_MOCKED_MODULES = {
+ 'OpenGL': _OPEN_GL_MOCK,
+ 'OpenGL.GL': _OPEN_GL_MOCK,
+}
+with mock.patch.dict('sys.modules', _MOCKED_MODULES):
+ from dm_control.viewer.gui import base
+# pylint: enable=g-import-not-at-top
+
+_EPSILON = 1e-7
+
+
+@mock.patch.object(base, 'time')
+class DoubleClickDetectorTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.detector = base.DoubleClickDetector()
+ self.double_click_event = (user_input.MOUSE_BUTTON_LEFT, user_input.PRESS)
+
+ def test_two_rapid_clicks_yield_double_click_event(self, mock_time):
+ mock_time.time.return_value = 0
+ self.assertFalse(self.detector.process(*self.double_click_event))
+
+ mock_time.time.return_value = base._DOUBLE_CLICK_INTERVAL - _EPSILON
+ self.assertTrue(self.detector.process(*self.double_click_event))
+
+ def test_two_slow_clicks_dont_yield_double_click_event(self, mock_time):
+ mock_time.time.return_value = 0
+ self.assertFalse(self.detector.process(*self.double_click_event))
+
+ mock_time.time.return_value = base._DOUBLE_CLICK_INTERVAL
+ self.assertFalse(self.detector.process(*self.double_click_event))
+
+ def test_sequence_of_slow_clicks_followed_by_fast_click(self, mock_time):
+ click_times = [(0., False),
+ (base._DOUBLE_CLICK_INTERVAL * 2., False),
+ (base._DOUBLE_CLICK_INTERVAL * 3. - _EPSILON, True)]
+ for click_time, result in click_times:
+ mock_time.time.return_value = click_time
+ self.assertEqual(result, self.detector.process(*self.double_click_event))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/gui/fullscreen_quad.py b/dm_control/viewer/gui/fullscreen_quad.py
new file mode 100644
index 00000000..ef17ac61
--- /dev/null
+++ b/dm_control/viewer/gui/fullscreen_quad.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""OpenGL utility for rendering numpy arrays as images on a quad surface."""
+
+import ctypes
+import numpy as np
+from OpenGL import GL
+from OpenGL.GL import shaders
+
+# This array contains the packed position and texture coordinates of a
+# fullscreen quad.
+# It defines 4 vertices that will be rendered as a triangle strip. Each vertex
+# is described by a tuple:
+# (VertexPosition.X, VertexPosition.Y, TextureCoord.U, TextureCoord.V)
+_FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS = np.array([
+ -1, -1, 0, 1,
+ -1, 1, 0, 0,
+ 1, -1, 1, 1,
+ 1, 1, 1, 0], dtype=np.float32)
+_FLOATS_PER_XY = 2
+_FLOATS_PER_VERTEX = 4
+_SIZE_OF_FLOAT = ctypes.sizeof(ctypes.c_float)
+
+_VERTEX_SHADER = """
+#version 120
+attribute vec2 position;
+attribute vec2 uv;
+void main() {
+ gl_Position = vec4(position, 0, 1);
+ gl_TexCoord[0].st = uv;
+}
+"""
+_FRAGMENT_SHADER = """
+#version 120
+uniform sampler2D tex;
+void main() {
+ gl_FragColor = texture2D(tex, gl_TexCoord[0].st);
+}
+"""
+_VAR_POSITION = 'position'
+_VAR_UV = 'uv'
+_VAR_TEXTURE_SAMPLER = 'tex'
+
+
+class FullscreenQuadRenderer:
+ """Renders pixmaps on a fullscreen quad using OpenGL."""
+
+ def __init__(self):
+ """Initializes the fullscreen quad renderer."""
+ GL.glClearColor(0, 0, 0, 0)
+ self._init_geometry()
+ self._init_texture()
+ self._init_shaders()
+
+ def _init_geometry(self):
+ """Initializes the fullscreen quad geometry."""
+ vertex_buffer = GL.glGenBuffers(1)
+ GL.glBindBuffer(GL.GL_ARRAY_BUFFER, vertex_buffer)
+ GL.glBufferData(
+ GL.GL_ARRAY_BUFFER,
+ _FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS.nbytes,
+ _FULLSCREEN_QUAD_VERTEX_POSITONS_AND_TEXTURE_COORDS, GL.GL_STATIC_DRAW)
+
+ def _init_texture(self):
+ """Initializes the texture storage."""
+ self._texture = GL.glGenTextures(1)
+ GL.glBindTexture(GL.GL_TEXTURE_2D, self._texture)
+ GL.glTexParameteri(
+ GL.GL_TEXTURE_2D, GL.GL_TEXTURE_MAG_FILTER, GL.GL_NEAREST)
+ GL.glTexParameteri(
+ GL.GL_TEXTURE_2D, GL.GL_TEXTURE_MIN_FILTER, GL.GL_NEAREST)
+
+ def _init_shaders(self):
+ """Initializes the shaders used to render the textured fullscreen quad."""
+ vs = shaders.compileShader(_VERTEX_SHADER, GL.GL_VERTEX_SHADER)
+ fs = shaders.compileShader(_FRAGMENT_SHADER, GL.GL_FRAGMENT_SHADER)
+ self._shader = shaders.compileProgram(vs, fs)
+
+ stride = _FLOATS_PER_VERTEX * _SIZE_OF_FLOAT
+ var_position = GL.glGetAttribLocation(self._shader, _VAR_POSITION)
+ GL.glVertexAttribPointer(
+ var_position, 2, GL.GL_FLOAT, GL.GL_FALSE, stride, None)
+ GL.glEnableVertexAttribArray(var_position)
+
+ var_uv = GL.glGetAttribLocation(self._shader, _VAR_UV)
+ uv_offset = ctypes.c_void_p(_FLOATS_PER_XY * _SIZE_OF_FLOAT)
+ GL.glVertexAttribPointer(
+ var_uv, 2, GL.GL_FLOAT, GL.GL_FALSE, stride, uv_offset)
+ GL.glEnableVertexAttribArray(var_uv)
+
+ self._var_texture_sampler = GL.glGetUniformLocation(
+ self._shader, _VAR_TEXTURE_SAMPLER)
+
+ def render(self, pixmap, viewport_shape):
+ """Renders the pixmap on a fullscreen quad.
+
+ Args:
+ pixmap: A 3D numpy array of bytes (np.uint8), with dimensions
+ (height, width, 3).
+ viewport_shape: A tuple of two elements, (width, height).
+ """
+ GL.glClear(GL.GL_COLOR_BUFFER_BIT)
+ GL.glViewport(0, 0, *viewport_shape)
+ GL.glUseProgram(self._shader)
+ GL.glActiveTexture(GL.GL_TEXTURE0)
+ GL.glBindTexture(GL.GL_TEXTURE_2D, self._texture)
+ GL.glPixelStorei(GL.GL_UNPACK_ALIGNMENT, 1)
+ GL.glTexImage2D(GL.GL_TEXTURE_2D, 0, GL.GL_RGB, pixmap.shape[1],
+ pixmap.shape[0], 0, GL.GL_RGB, GL.GL_UNSIGNED_BYTE,
+ pixmap)
+ GL.glUniform1i(self._var_texture_sampler, 0)
+ GL.glDrawArrays(GL.GL_TRIANGLE_STRIP, 0, 4)
diff --git a/dm_control/viewer/gui/glfw_gui.py b/dm_control/viewer/gui/glfw_gui.py
new file mode 100644
index 00000000..1dd86b3d
--- /dev/null
+++ b/dm_control/viewer/gui/glfw_gui.py
@@ -0,0 +1,325 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Windowing system that uses GLFW library."""
+
+import functools
+from dm_control import _render
+from dm_control._render import glfw_renderer
+from dm_control.viewer import util
+from dm_control.viewer.gui import base
+from dm_control.viewer.gui import fullscreen_quad
+import glfw
+import numpy as np
+
+
+def _check_valid_backend(func):
+ """Decorator which checks that GLFW is being used for offscreen rendering."""
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ if _render.BACKEND != 'glfw':
+ raise RuntimeError(
+ '{func} may only be called if using GLFW for offscreen rendering, '
+ 'got `_render.BACKEND={backend!r}`.'.format(
+ func=func, backend=_render.BACKEND))
+ return func(*args, **kwargs)
+ return wrapped_func
+
+
+class DoubleBufferedGlfwContext(glfw_renderer.GLFWContext):
+ """Custom context manager for the GLFW based GUI."""
+
+ def __init__(self, width, height, title):
+ self._title = title
+ super().__init__(max_width=width, max_height=height)
+
+ @_check_valid_backend
+ def _platform_init(self, width, height):
+ glfw.window_hint(glfw.SAMPLES, 4)
+ glfw.window_hint(glfw.VISIBLE, 1)
+ glfw.window_hint(glfw.DOUBLEBUFFER, 1)
+ self._context = glfw.create_window(width, height, self._title, None, None)
+ self._destroy_window = glfw.destroy_window
+
+ @property
+ def window(self):
+ return self._context
+
+
+class GlfwKeyboard(base.InputEventsProcessor):
+ """Glfw keyboard device handler.
+
+ Handles the keyboard input in a thread-safe way, and forwards the events
+ to the registered callbacks.
+
+ Attributes:
+ on_key: Observable subject triggered when a key event is received.
+ Expects a callback with signature: (key, activity, modifiers)
+ """
+
+ def __init__(self, context):
+ super().__init__()
+ with context.make_current() as ctx:
+ ctx.call(glfw.set_key_callback, context.window, self._handle_key_event)
+ self.on_key = util.QuietSet()
+
+ def _handle_key_event(self, window, key, scancode, activity, mods):
+ """Broadcasts the notification to registered listeners.
+
+ Args:
+ window: The window that received the event.
+ key: ID representing the key, a glfw.KEY_ constant.
+ scancode: The system-specific scancode of the key.
+ activity: glfw.PRESS, glfw.RELEASE or glfw.REPEAT.
+ mods: Bit field describing which modifier keys were held down, such as Alt
+ or Shift.
+ """
+ del window, scancode
+ self.add_event(self.on_key, key, activity, mods)
+
+
+class GlfwMouse(base.InputEventsProcessor):
+ """Glfw mouse device handler.
+
+ Handles the mouse input in a thread-safe way, forwarding the events to the
+ registered callbacks.
+
+ Attributes:
+ on_move: Observable subject triggered when a mouse move is detected.
+ Expects a callback with signature (position, translation).
+ on_click: Observable subject triggered when a mouse click is detected.
+ Expects a callback with signature (button, action, modifiers).
+ on_double_click: Observable subject triggered when a mouse double click is
+ detected. Expects a callback with signature (button, modifiers).
+ on_scroll: Observable subject triggered when a mouse scroll is detected.
+ Expects a callback with signature (scroll_value).
+ """
+
+ def __init__(self, context):
+ super().__init__()
+ self.on_move = util.QuietSet()
+ self.on_click = util.QuietSet()
+ self.on_double_click = util.QuietSet()
+ self.on_scroll = util.QuietSet()
+ self._double_click_detector = base.DoubleClickDetector()
+ with context.make_current() as ctx:
+ framebuffer_width, window_width = ctx.call(
+ self._glfw_setup, context.window)
+
+ self._scale = framebuffer_width * 1.0 / window_width
+ self._last_mouse_pos = np.zeros(2, int)
+
+ self._double_clicks = {}
+
+ def _glfw_setup(self, window):
+ glfw.set_cursor_pos_callback(window, self._handle_move)
+ glfw.set_mouse_button_callback(window, self._handle_button)
+ glfw.set_scroll_callback(window, self._handle_scroll)
+ framebuffer_width, _ = glfw.get_framebuffer_size(window)
+ window_width, _ = glfw.get_window_size(window)
+ return framebuffer_width, window_width
+
+ @property
+ def position(self):
+ return self._last_mouse_pos
+
+ def _handle_move(self, window, x, y):
+ """Mouse movement callback.
+
+ Args:
+ window: Window object from glfw.
+ x: Horizontal position of mouse, in pixels.
+ y: Vertical position of mouse, in pixels.
+ """
+ del window
+ position = np.array([x, y], int) * self._scale
+ delta = position - self._last_mouse_pos
+ self._last_mouse_pos = position
+ self.add_event(self.on_move, position, delta)
+
+ def _handle_button(self, window, button, act, mods):
+ """Mouse button click event handler."""
+ del window
+ self.add_event(self.on_click, button, act, mods)
+ if self._double_click_detector.process(button, act):
+ self.add_event(self.on_double_click, button, mods)
+
+ def _handle_scroll(self, window, x_offset, y_offset):
+ """Mouse wheel scroll event handler."""
+ del window, x_offset
+ self.add_event(self.on_scroll, y_offset)
+
+
+class GlfwWindow:
+ """A GLFW based application window.
+
+ Attributes:
+ on_files_drop: An observable subject, instance of util.QuietSet. Attached
+ listeners, callables taking one argument, will be invoked every time the
+ user drops files onto the window. The callable will be passed an iterable
+ with dropped file paths.
+ is_full_screen: Boolean, whether the window is currently full-screen.
+ """
+
+ def __init__(self, width, height, title, context=None):
+ """Instance initializer.
+
+ Args:
+ width: Initial window width, in pixels.
+ height: Initial window height, in pixels.
+ title: A string with a window title.
+ context: (Optional) A `render.GLFWContext` instance.
+
+ Raises:
+ RuntimeError: If GLFW initialization or window initialization fails.
+ """
+ super().__init__()
+ self._context = context or DoubleBufferedGlfwContext(width, height, title)
+
+ if not self._context.window:
+ raise RuntimeError('Failed to create window')
+
+ self._oldsize = None
+
+ with self._context.make_current() as ctx:
+ self._fullscreen_quad = ctx.call(self._glfw_setup, self._context.window)
+ self.on_files_drop = util.QuietSet()
+
+ self._keyboard = GlfwKeyboard(self._context)
+ self._mouse = GlfwMouse(self._context)
+
+ def _glfw_setup(self, window):
+ glfw.set_drop_callback(window, self._handle_file_drop)
+ return fullscreen_quad.FullscreenQuadRenderer()
+
+ @property
+ def shape(self):
+ """Returns a tuple with the shape of the window, (width, height)."""
+ with self._context.make_current() as ctx:
+ return ctx.call(glfw.get_framebuffer_size, self._context.window)
+
+ @property
+ def position(self):
+ """Returns a tuple with top-left window corner's coordinates, (x, y)."""
+ with self._context.make_current() as ctx:
+ return ctx.call(glfw.get_window_pos, self._context.window)
+
+ @property
+ def keyboard(self):
+ """Returns a GlfwKeyboard instance associated with the window."""
+ return self._keyboard
+
+ @property
+ def mouse(self):
+ """Returns a GlfwMouse instance associated with the window."""
+ return self._mouse
+
+ def set_title(self, title):
+ """Sets the window title.
+
+ Args:
+ title: A string, title of the window.
+ """
+ with self._context.make_current() as ctx:
+ ctx.call(glfw.set_window_title, self._context.window, title)
+
+ def set_full_screen(self, enable):
+ """Expands the main application window to full screen or restores it.
+
+ Args:
+ enable: Boolean flag, True expands the window to full-screen mode, False
+ restores it to its former position and size.
+ """
+ if enable == self.is_full_screen:
+ return
+
+ if enable:
+ self._oldsize = list(self.position) + list(self.shape)
+ def enable_full_screen(window):
+ display = glfw.get_primary_monitor()
+ videomode = glfw.get_video_mode(display)
+ glfw.set_window_monitor(window, display, 0, 0, videomode[0][0],
+ videomode[0][1], videomode[2])
+ with self._context.make_current() as ctx:
+ ctx.call(enable_full_screen, self._context.window)
+ else:
+ with self._context.make_current() as ctx:
+ ctx.call(glfw.set_window_monitor,
+ self._context.window, None, self._oldsize[0],
+ self._oldsize[1], self._oldsize[2],
+ self._oldsize[3], 0)
+ self._oldsize = None
+
+ def toggle_full_screen(self):
+ """Switches the window between full-screen and windowed mode."""
+ show_full_screen = not self.is_full_screen
+ self.set_full_screen(show_full_screen)
+
+ @property
+ def is_full_screen(self):
+ return self._oldsize is not None
+
+ def free(self):
+ """Closes the window and releases its resources."""
+ self.close()
+
+ def event_loop(self, tick_func):
+ """Runs the window's event loop.
+
+ This is a blocking call that won't exit until the window is closed.
+
+ Args:
+ tick_func: A callable, function to call every frame.
+ """
+ while not glfw.window_should_close(self._context.window):
+ self.update(tick_func)
+
+ def update(self, render_func):
+ """Updates the window and renders a new image.
+
+ Args:
+ render_func: A callable returning a 3D numpy array of bytes (np.uint8),
+ with dimensions (height, width, 3).
+ """
+ pixels = render_func()
+ with self._context.make_current() as ctx:
+ ctx.call(
+ self._update_gui_on_render_thread, self._context.window, pixels)
+ self._mouse.process_events()
+ self._keyboard.process_events()
+
+ def _update_gui_on_render_thread(self, window, pixels):
+ self._fullscreen_quad.render(pixels, self.shape)
+ glfw.swap_buffers(window)
+ glfw.poll_events()
+
+ def close(self):
+ """Closes the window and releases associated resources."""
+ if self._context is not None:
+ self._context.free()
+ self._context = None
+
+ def _handle_file_drop(self, window, paths):
+ """Handles events of user dropping files onto the window.
+
+ Args:
+ window: GLFW window handle (unused).
+ paths: An iterable with paths of files dropped onto the window.
+ """
+ del window
+ for listener in list(self.on_files_drop):
+ listener(paths)
+
+ def __del__(self):
+ self.free()
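The full-screen handling in `GlfwWindow` above keeps a single piece of state: the saved window geometry, where `None` means "windowed". A minimal sketch of that bookkeeping follows, assuming a hypothetical `FakeWindow` stand-in instead of a real GLFW window; `FakeWindow`, `FullScreenToggle`, and `monitor_size` are illustrative names, not part of `dm_control`.

```python
class FakeWindow:
    """Hypothetical stand-in for a GLFW window; it only records geometry."""

    def __init__(self, x, y, width, height):
        self.geometry = (x, y, width, height)


class FullScreenToggle:
    """Tracks full-screen state through the saved windowed geometry."""

    def __init__(self, window, monitor_size):
        self._window = window
        self._monitor_size = monitor_size
        self._old_geometry = None  # None means the window is not full screen.

    @property
    def is_full_screen(self):
        return self._old_geometry is not None

    def set_full_screen(self, enable):
        if enable == self.is_full_screen:
            return  # Already in the requested state.
        if enable:
            # Remember where the window was so it can be restored later.
            self._old_geometry = self._window.geometry
            self._window.geometry = (0, 0) + self._monitor_size
        else:
            self._window.geometry = self._old_geometry
            self._old_geometry = None

    def toggle(self):
        self.set_full_screen(not self.is_full_screen)


window = FakeWindow(5, 6, 3, 4)
toggle = FullScreenToggle(window, monitor_size=(1920, 1080))
toggle.toggle()
full_screen_geometry = window.geometry
toggle.toggle()
restored_geometry = window.geometry
```

Deriving `is_full_screen` from the saved geometry avoids a separate boolean that could drift out of sync with it.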
diff --git a/dm_control/viewer/gui/glfw_gui_test.py b/dm_control/viewer/gui/glfw_gui_test.py
new file mode 100644
index 00000000..7f5edaca
--- /dev/null
+++ b/dm_control/viewer/gui/glfw_gui_test.py
@@ -0,0 +1,234 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the GLFW based windowing system."""
+
+import contextlib
+
+from absl.testing import absltest
+from dm_control.viewer import user_input
+import mock
+import numpy as np
+
+
+_OPEN_GL_MOCK = mock.MagicMock()
+_GLFW_MOCK = mock.MagicMock()
+_MOCKED_MODULES = {
+ 'OpenGL': _OPEN_GL_MOCK,
+ 'OpenGL.GL': _OPEN_GL_MOCK,
+ 'glfw': _GLFW_MOCK,
+}
+with mock.patch.dict('sys.modules', _MOCKED_MODULES):
+ from dm_control.viewer.gui import glfw_gui # pylint: disable=g-import-not-at-top
+
+glfw_gui.base.GL = _OPEN_GL_MOCK
+glfw_gui.base.shaders = _OPEN_GL_MOCK
+glfw_gui.fullscreen_quad.GL = _OPEN_GL_MOCK
+glfw_gui.fullscreen_quad.shaders = _OPEN_GL_MOCK
+glfw_gui.glfw = _GLFW_MOCK
+
+# Pretend we are using GLFW for offscreen rendering so that the runtime backend
+# check will succeed.
+glfw_gui._render.BACKEND = 'glfw'
+
+_EPSILON = 1e-7
+
+
+class GlfwKeyboardTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ _GLFW_MOCK.reset_mock()
+ self.context = mock.MagicMock()
+ self.handler = glfw_gui.GlfwKeyboard(self.context)
+
+ self.events = [
+ (1, user_input.KEY_T, 't', user_input.PRESS, user_input.MOD_ALT),
+ (1, user_input.KEY_T, 't', user_input.RELEASE, user_input.MOD_ALT),
+ (1, user_input.KEY_A, 't', user_input.PRESS, 0)]
+
+ self.listener = mock.MagicMock()
+ self.handler.on_key += [self.listener]
+
+ def test_single_event(self):
+ self.handler._handle_key_event(*self.events[0])
+ self.handler.process_events()
+ self.listener.assert_called_once_with(user_input.KEY_T, user_input.PRESS,
+ user_input.MOD_ALT)
+
+ def test_sequence_of_events(self):
+ for event in self.events:
+ self.handler._handle_key_event(*event)
+ self.handler.process_events()
+ self.assertEqual(3, self.listener.call_count)
+ for event, call_args in zip(self.events, self.listener.call_args_list):
+ expected_event = tuple([event[1]] + list(event[-2:]))
+ self.assertEqual(expected_event, call_args[0])
+
+
+class FakePassthroughRenderingContext:
+
+ def __init__(self):
+ self.window = 0
+
+ def call(self, func, *args):
+ return func(*args)
+
+
+class GlfwMouseTest(absltest.TestCase):
+
+ @contextlib.contextmanager
+ def fake_make_current(self):
+ yield FakePassthroughRenderingContext()
+
+ def setUp(self):
+ super().setUp()
+ _GLFW_MOCK.reset_mock()
+ _GLFW_MOCK.get_framebuffer_size = mock.MagicMock(return_value=(256, 256))
+ _GLFW_MOCK.get_window_size = mock.MagicMock(return_value=(256, 256))
+ self.window = mock.MagicMock()
+ self.window.make_current = mock.MagicMock(
+ side_effect=self.fake_make_current)
+ self.handler = glfw_gui.GlfwMouse(self.window)
+
+ def test_moving_mouse(self):
+ def move_handler(position, translation):
+ self.position = position
+ self.translation = translation
+ self.new_position = [100, 100]
+ self.handler._last_mouse_pos = np.array([99, 101], int)
+ self.handler.on_move += move_handler
+
+ self.handler._handle_move(self.window, self.new_position[0],
+ self.new_position[1])
+ self.handler.process_events()
+ np.testing.assert_array_equal(self.new_position, self.position)
+ np.testing.assert_array_equal([1, -1], self.translation)
+
+ def test_button_click(self):
+ def click_handler(button, action, modifiers):
+ self.button = button
+ self.action = action
+ self.modifiers = modifiers
+ self.handler.on_click += click_handler
+
+ self.handler._handle_button(self.window, user_input.MOUSE_BUTTON_LEFT,
+ user_input.PRESS, user_input.MOD_SHIFT)
+ self.handler.process_events()
+
+ self.assertEqual(user_input.MOUSE_BUTTON_LEFT, self.button)
+ self.assertEqual(user_input.PRESS, self.action)
+ self.assertEqual(user_input.MOD_SHIFT, self.modifiers)
+
+ def test_scroll(self):
+ def scroll_handler(position):
+ self.position = position
+ self.handler.on_scroll += scroll_handler
+
+ x_value = 10
+ y_value = 20
+ self.handler._handle_scroll(self.window, x_value, y_value)
+ self.handler.process_events()
+ # x_value gets ignored, it's the y_value - the vertical scroll - we're
+ # interested in.
+ self.assertEqual(y_value, self.position)
+
+
+class GlfwWindowTest(absltest.TestCase):
+
+ WIDTH = 10
+ HEIGHT = 20
+
+ @contextlib.contextmanager
+ def fake_make_current(self):
+ yield FakePassthroughRenderingContext()
+
+ def setUp(self):
+ super().setUp()
+ _GLFW_MOCK.reset_mock()
+ _GLFW_MOCK.get_video_mode.return_value = (None, None, 60)
+ _GLFW_MOCK.get_framebuffer_size.return_value = (4, 5)
+ _GLFW_MOCK.get_window_size.return_value = (self.WIDTH, self.HEIGHT)
+ self.context = mock.MagicMock()
+ self.context.make_current = mock.MagicMock(
+ side_effect=self.fake_make_current)
+ self.window = glfw_gui.GlfwWindow(
+ self.WIDTH, self.HEIGHT, 'title', self.context)
+
+ def test_window_shape(self):
+ expected_shape = (self.WIDTH, self.HEIGHT)
+ _GLFW_MOCK.get_framebuffer_size.return_value = expected_shape
+ self.assertEqual(expected_shape, self.window.shape)
+
+ def test_window_position(self):
+ expected_position = (1, 2)
+ _GLFW_MOCK.get_window_pos.return_value = expected_position
+ self.assertEqual(expected_position, self.window.position)
+
+ def test_freeing_context(self):
+ self.window.close = mock.MagicMock()
+ self.window.free()
+ self.window.close.assert_called_once()
+
+ def test_close(self):
+ self.window.close()
+ self.context.free.assert_called_once()
+ self.assertIsNone(self.window._context)
+
+ def test_closing_window_that_has_already_been_closed(self):
+ self.window._context = None
+ self.window.close()
+ self.assertEqual(0, _GLFW_MOCK.destroy_window.call_count)
+
+ def test_file_drop(self):
+ self.expected_paths = ['path1', 'path2']
+ def callback(paths):
+ self.assertEqual(self.expected_paths, paths)
+
+ was_called_mock = mock.MagicMock()
+ self.window.on_files_drop += [callback, was_called_mock]
+ self.window._handle_file_drop('window_handle', self.expected_paths)
+ was_called_mock.assert_called_once()
+
+ def test_setting_title(self):
+ new_title = 'new_title'
+ self.window.set_title(new_title)
+ self.assertEqual(new_title, _GLFW_MOCK.set_window_title.call_args[0][1])
+
+ def test_enabling_full_screen(self):
+ full_screen_pos = (0, 0)
+ full_screen_size = (1, 2)
+ window_size = (3, 4)
+ window_pos = (5, 6)
+ reserved_value = 7
+ full_size_mode = 8
+
+ _GLFW_MOCK.get_framebuffer_size.return_value = window_size
+ _GLFW_MOCK.get_window_pos.return_value = window_pos
+ _GLFW_MOCK.get_video_mode.return_value = (
+ full_screen_size, reserved_value, full_size_mode)
+
+ self.window.set_full_screen(True)
+ _GLFW_MOCK.set_window_monitor.assert_called_once()
+
+ new_position = _GLFW_MOCK.set_window_monitor.call_args[0][2:4]
+ new_size = _GLFW_MOCK.set_window_monitor.call_args[0][4:6]
+ new_mode = _GLFW_MOCK.set_window_monitor.call_args[0][6]
+ self.assertEqual(full_screen_pos, new_position)
+ self.assertEqual(full_screen_size, new_size)
+ self.assertEqual(full_size_mode, new_mode)
+
+
+if __name__ == '__main__':
+ absltest.main()
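The test module above imports `glfw_gui` while `sys.modules` is patched, so the module under test binds to mocks instead of the real `glfw` and `OpenGL` packages. A minimal sketch of the same technique, using the standard-library `unittest.mock` rather than the third-party `mock` package and a hypothetical dependency called `heavy_gl_lib` (both the dependency and `gui_under_test` are made up for illustration):

```python
import sys
import types
from unittest import mock

# Source of a hypothetical module that imports a library we do not want to load.
_SOURCE = '''
import heavy_gl_lib

def window_size(handle):
    return heavy_gl_lib.get_window_size(handle)
'''

fake_lib = mock.MagicMock()
fake_lib.get_window_size.return_value = (10, 20)

module = types.ModuleType('gui_under_test')
# While the patch is active, `import heavy_gl_lib` resolves to the mock.
with mock.patch.dict('sys.modules', {'heavy_gl_lib': fake_lib}):
    exec(_SOURCE, module.__dict__)

# The module keeps its reference to the mock after the patch is undone.
size = module.window_size('handle')
leaked = 'heavy_gl_lib' in sys.modules
```

Because `patch.dict` restores `sys.modules` on exit, the fake entry does not leak into other tests, while the already-imported module keeps using the mock.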
diff --git a/dm_control/viewer/policy.gif b/dm_control/viewer/policy.gif
new file mode 100644
index 00000000..94238c65
Binary files /dev/null and b/dm_control/viewer/policy.gif differ
diff --git a/dm_control/viewer/renderer.py b/dm_control/viewer/renderer.py
new file mode 100644
index 00000000..e5d938e7
--- /dev/null
+++ b/dm_control/viewer/renderer.py
@@ -0,0 +1,686 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Renderer module."""
+
+import abc
+import contextlib
+import subprocess
+import sys
+
+from dm_control.mujoco import wrapper
+from dm_control.viewer import util
+import mujoco
+import numpy as np
+
+
+# Camera index -1 is the free (unfixed) camera, and each fixed camera has
+# a non-negative index in the range [0, self._model.ncam).
+_FREE_CAMERA_INDEX = -1
+
+# Index used to distinguish when a camera isn't tracking any particular body.
+_NO_BODY_TRACKED_INDEX = -1
+
+# Index used to distinguish a non-existing or an invalid body.
+_INVALID_BODY_INDEX = -1
+
+# Zoom factor used when zooming in on the entire scene.
+_FULL_SCENE_ZOOM_FACTOR = 1.5
+
+# Default values for `MjvScene.flags`. These are the same as the defaults set by
+# `mjv_defaultScene`, except that we disable `mjRND_HAZE`.
+_DEFAULT_RENDER_FLAGS = np.zeros(mujoco.mjtRndFlag.mjNRNDFLAG, dtype=np.ubyte)
+_DEFAULT_RENDER_FLAGS[mujoco.mjtRndFlag.mjRND_SHADOW.value] = 1
+_DEFAULT_RENDER_FLAGS[mujoco.mjtRndFlag.mjRND_REFLECTION.value] = 1
+_DEFAULT_RENDER_FLAGS[mujoco.mjtRndFlag.mjRND_SKYBOX.value] = 1
+_DEFAULT_RENDER_FLAGS[mujoco.mjtRndFlag.mjRND_CULL_FACE.value] = 1
+
+# Font scale values.
+_DEFAULT_FONT_SCALE = mujoco.mjtFontScale.mjFONTSCALE_100
+_HIDPI_FONT_SCALE = mujoco.mjtFontScale.mjFONTSCALE_200
+
+
+def _has_high_dpi() -> bool:
+ """Returns True if the display is a high DPI display."""
+ if sys.platform == 'darwin':
+ # On macOS, we can use the system_profiler command to determine if the
+ # display is retina.
+ return subprocess.call(
+ 'system_profiler SPDisplaysDataType | grep -i "retina"',
+ shell=True,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL
+ ) == 0
+ # TODO(zakka): Figure out how to detect high DPI displays on Linux.
+ return False
+
+
+class BaseRenderer(metaclass=abc.ABCMeta):
+ """A base class for component-based MuJoCo renderer implementations.
+
+ Attributes:
+ components: A set of Component instances the renderer will render in
+ addition to the physics scene. Being a QuietSet instance, it supports
+ adding and removing components using the += and -= operators.
+ screen_capture_components: Components that perform screen capture and need
+ a guarantee to be called when all other elements have been rendered.
+ """
+
+ def __init__(self):
+ """Instance initializer."""
+ self.components = util.QuietSet()
+ self.screen_capture_components = util.QuietSet()
+
+ def _render_components(self, context, viewport):
+ """Renders the components.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ """
+ for component in self.components:
+ component.render(context, viewport)
+ for component in self.screen_capture_components:
+ component.render(context, viewport)
+
+
+class Component(metaclass=abc.ABCMeta):
+ """Components are a way to introduce extra rendering content.
+
+ They are invoked after the main rendering pass, allowing extra images, such
+ as overlays, to be drawn into the render buffer.
+ """
+
+ @abc.abstractmethod
+ def render(self, context, viewport):
+ """Renders the component.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ """
+ pass
+
+
+class NullRenderer:
+ """A stub off-screen renderer used when no other renderer is available."""
+
+ def __init__(self):
+ """Instance initializer."""
+ self._black = np.zeros((1, 1, 3), dtype=np.uint8)
+
+ def release(self):
+ pass
+
+ @property
+ def pixels(self):
+ """Returns a black pixel map."""
+ return self._black
+
+
+class OffScreenRenderer(BaseRenderer):
+ """A Mujoco renderer that renders to an off-screen surface."""
+
+ def __init__(self, model, surface):
+ """Instance initializer.
+
+ Args:
+ model: instance of MjModel.
+ surface: instance of dm_control.render.BaseContext.
+ """
+ super().__init__()
+ self._surface = surface
+ self._surface.increment_refcount()
+ self._model = model
+ self._mujoco_context = None
+
+ self._prev_viewport = np.ones(2)
+ self._rgb_buffer = np.empty((1, 1, 3), dtype=np.uint8)
+ self._pixels = np.zeros((1, 1, 3), dtype=np.uint8)
+
+ def render(self, viewport, scene):
+ """Renders the scene to the specified viewport.
+
+ Args:
+ viewport: Instance of Viewport.
+ scene: Instance of MjvScene.
+ Returns:
+ None. The image is read into a (viewport.height, viewport.width, 3)
+ buffer and exposed through the `pixels` property.
+ """
+ if not np.array_equal(self._prev_viewport, viewport.dimensions):
+ self._prev_viewport = viewport.dimensions
+ if self._mujoco_context:
+ self._mujoco_context.free()
+ self._mujoco_context = None
+ if not self._mujoco_context:
+ # Ensure that MuJoCo's offscreen framebuffer is large enough to
+ # accommodate the viewport.
+ new_offwidth = max(self._model.vis.global_.offwidth, viewport.width)
+ new_offheight = max(self._model.vis.global_.offheight, viewport.height)
+ self._model.vis.global_.offwidth = new_offwidth
+ self._model.vis.global_.offheight = new_offheight
+ font_scale = _HIDPI_FONT_SCALE if _has_high_dpi() else _DEFAULT_FONT_SCALE
+ self._mujoco_context = wrapper.MjrContext(
+ model=self._model,
+ gl_context=self._surface,
+ font_scale=font_scale)
+ self._rgb_buffer = np.empty(
+ (viewport.height, viewport.width, 3), dtype=np.uint8)
+
+ with self._surface.make_current() as ctx:
+ ctx.call(self._render_on_gl_thread, viewport, scene)
+ self._pixels = self._surface.to_pixels(self._rgb_buffer)
+
+ def _render_on_gl_thread(self, viewport, scene):
+ mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN,
+ self._mujoco_context.ptr)
+ mujoco.mjr_render(viewport.mujoco_rect, scene.ptr, self._mujoco_context.ptr)
+ self._render_components(self._mujoco_context, viewport)
+ mujoco.mjr_readPixels(self._rgb_buffer, None, viewport.mujoco_rect,
+ self._mujoco_context.ptr)
+
+ def release(self):
+ """Releases the render context and related resources."""
+ if self._mujoco_context:
+ self._mujoco_context.free()
+ self._mujoco_context = None
+ if self._surface:
+ self._surface.decrement_refcount()
+ self._surface.free()
+ self._surface = None
+
+ @property
+ def pixels(self):
+ """Returns the rendered image."""
+ return self._pixels
+
+
+class Perturbation:
+ """A proxy that allows a scene object to be moved."""
+
+ def __init__(self, body_id, model, data, scene):
+ """Instance initializer.
+
+ Args:
+ body_id: A positive integer, ID of the body to manipulate.
+ model: MjModel instance.
+ data: MjData instance.
+ scene: MjvScene instance.
+ """
+ self._body_id = body_id
+ self._model = model
+ self._data = data
+ self._scene = scene
+ self._action = mujoco.mjtMouse.mjMOUSE_NONE
+
+ self._perturb = wrapper.MjvPerturb()
+ self._perturb.select = self._body_id
+ self._perturb.active = 0
+
+ mujoco.mjv_initPerturb(self._model.ptr, self._data.ptr, self._scene.ptr,
+ self._perturb.ptr)
+
+ def start_move(self, action, grab_pos):
+ """Starts a movement action."""
+ if action is None or grab_pos is None:
+ return
+
+ body_pos = self._data.xpos[self._body_id]
+ body_mat = self._data.xmat[self._body_id].reshape(3, 3)
+ grab_local_pos = body_mat.T.dot(grab_pos - body_pos)
+ self._perturb.localpos[:] = grab_local_pos
+
+ mujoco.mjv_initPerturb(self._model.ptr, self._data.ptr, self._scene.ptr,
+ self._perturb.ptr)
+ self._action = action
+ move_actions = np.array(
+ [mujoco.mjtMouse.mjMOUSE_MOVE_V, mujoco.mjtMouse.mjMOUSE_MOVE_H])
+ if any(move_actions == action):
+ self._perturb.active = mujoco.mjtPertBit.mjPERT_TRANSLATE
+ else:
+ self._perturb.active = mujoco.mjtPertBit.mjPERT_ROTATE
+
+ def tick_move(self, viewport_offset):
+ """Transforms object's location/rotation by the specified amount."""
+ if self._action and self._action != mujoco.mjtMouse.mjMOUSE_NONE:
+ mujoco.mjv_movePerturb(self._model.ptr, self._data.ptr, self._action,
+ viewport_offset[0], viewport_offset[1],
+ self._scene.ptr, self._perturb.ptr)
+
+ def end_move(self):
+ """Ends a movement operation."""
+ self._action = mujoco.mjtMouse.mjMOUSE_NONE
+ self._perturb.active = 0
+
+ @contextlib.contextmanager
+ def apply(self, paused):
+ """Applies the modifications introduced by performing the move operation."""
+ mujoco.mjv_applyPerturbPose(self._model.ptr, self._data.ptr,
+ self._perturb.ptr, int(paused))
+ if not paused:
+ mujoco.mjv_applyPerturbForce(self._model.ptr, self._data.ptr,
+ self._perturb.ptr)
+ yield
+ self._data.xfrc_applied[self._perturb.select] = 0
+
+ @property
+ def ptr(self):
+ """Returns the underlying Mujoco Perturbation object."""
+ return self._perturb.ptr
+
+ @property
+ def body_id(self):
+ """A positive integer, ID of the manipulated body."""
+ return self._body_id
+
+
+class NullPerturbation:
+ """An empty perturbation.
+
+ A null-object pattern, used to avoid cumbersome if clauses.
+ """
+
+ @contextlib.contextmanager
+ def apply(self, paused):
+ """Activates/deactivates the null context."""
+ del paused
+ yield
+
+ @property
+ def ptr(self):
+ """Returns None, because this class represents an empty perturbation."""
+ return None
+
+
+class RenderSettings:
+ """Renderer settings."""
+
+ def __init__(self):
+ self._visualization_options = wrapper.MjvOption()
+ self._stereo_mode = mujoco.mjtStereo.mjSTEREO_NONE
+ self._render_flags = _DEFAULT_RENDER_FLAGS.copy()
+
+ @property
+ def visualization(self):
+ """Returns scene visualization options."""
+ return self._visualization_options
+
+ @property
+ def render_flags(self):
+ """Returns the render flags."""
+ return self._render_flags
+
+ @property
+ def visualization_flags(self):
+ """Returns scene visualization flags."""
+ return self._visualization_options.flags
+
+ @property
+ def geom_groups(self):
+ """Returns geom groups visibility flags."""
+ return self._visualization_options.geomgroup
+
+ @property
+ def site_groups(self):
+ """Returns site groups visibility flags."""
+ return self._visualization_options.sitegroup
+
+ def apply_settings(self, scene):
+ """Applies settings to the specified scene.
+
+ Args:
+ scene: Instance of MjvScene.
+ """
+ scene.stereo = self._stereo_mode
+ scene.flags[:] = self._render_flags[:]
+
+ def toggle_rendering_flag(self, flag_index):
+ """Toggles the specified rendering flag."""
+ self._render_flags[flag_index] = not self._render_flags[flag_index]
+
+ def toggle_visualization_flag(self, flag_index):
+ """Toggles the specified visualization flag."""
+ self._visualization_options.flags[flag_index] = (
+ not self._visualization_options.flags[flag_index])
+
+ def toggle_geom_group(self, group_index):
+ """Toggles the visibility of the specified geom group."""
+ self._visualization_options.geomgroup[group_index] = (
+ not self._visualization_options.geomgroup[group_index])
+
+ def toggle_site_group(self, group_index):
+ """Toggles the visibility of the specified site group."""
+ self._visualization_options.sitegroup[group_index] = (
+ not self._visualization_options.sitegroup[group_index])
+
+ def toggle_stereo_buffering(self):
+ """Toggles quad-buffered stereo rendering on/off."""
+ if self._stereo_mode == mujoco.mjtStereo.mjSTEREO_NONE:
+ self._stereo_mode = mujoco.mjtStereo.mjSTEREO_QUADBUFFERED
+ else:
+ self._stereo_mode = mujoco.mjtStereo.mjSTEREO_NONE
+
+ def select_next_rendering_mode(self):
+ """Cycles to the next frame visualization mode."""
+ self._visualization_options.frame = (
+ (self._visualization_options.frame + 1) % mujoco.mjtFrame.mjNFRAME)
+
+ def select_prev_rendering_mode(self):
+ """Cycles to the previous frame visualization mode."""
+ self._visualization_options.frame = (
+ (self._visualization_options.frame - 1) % mujoco.mjtFrame.mjNFRAME)
+
+ def select_next_labeling_mode(self):
+ """Cycles to the next scene object labeling mode."""
+ self._visualization_options.label = (
+ (self._visualization_options.label + 1) % mujoco.mjtLabel.mjNLABEL)
+
+ def select_prev_labeling_mode(self):
+ """Cycles to the previous scene object labeling mode."""
+ self._visualization_options.label = (
+ (self._visualization_options.label - 1) % mujoco.mjtLabel.mjNLABEL)
+
+
+class SceneCamera:
+ """A camera used to navigate around and render the scene."""
+
+ def __init__(self,
+ model,
+ data,
+ options,
+ settings=None,
+ zoom_factor=_FULL_SCENE_ZOOM_FACTOR,
+ scene_callback=None):
+ """Instance initializer.
+
+ Args:
+ model: MjModel instance.
+ data: MjData instance.
+ options: RenderSettings instance.
+ settings: Optional, internal camera settings obtained from another
+ SceneCamera instance using 'settings' property.
+ zoom_factor: The initial zoom factor for zooming into the scene.
+ scene_callback: Scene callback.
+ This is a callable of the form: `my_callable(MjModel, MjData, MjvScene)`
+ that gets applied to every rendered scene.
+ """
+ # Design notes:
+ # We need to recreate the camera for each new model, because each model
+ # defines different fixed cameras and objects to track, and therefore
+ # severely affects the parameters of this class.
+ self._scene = wrapper.MjvScene(model)
+ self._data = data
+ self._model = model
+ self._options = options
+
+ self._camera = wrapper.MjvCamera()
+ self.set_freelook_mode()
+ self._zoom_factor = zoom_factor
+ self._scene_callback = scene_callback
+
+ if settings is not None:
+ self._settings = settings
+ self.settings = settings
+ else:
+ self._settings = self._camera
+
+ def set_freelook_mode(self):
+ """Enables 6 degrees of freedom of movement for the camera."""
+ self._camera.trackbodyid = _NO_BODY_TRACKED_INDEX
+ self._camera.fixedcamid = _FREE_CAMERA_INDEX
+ self._camera.type_ = mujoco.mjtCamera.mjCAMERA_FREE
+ mujoco.mjv_defaultFreeCamera(self._model.ptr, self._camera.ptr)
+
+ def set_tracking_mode(self, body_id):
+ """Latches the camera onto the specified body.
+
+ Leaves the user only 3 degrees of freedom to rotate the camera.
+
+ Args:
+ body_id: A non-negative integer, ID of the body to track.
+ """
+ if body_id < 0:
+ return
+ self._camera.trackbodyid = body_id
+ self._camera.fixedcamid = _FREE_CAMERA_INDEX
+ self._camera.type_ = mujoco.mjtCamera.mjCAMERA_TRACKING
+
+ def set_fixed_mode(self, fixed_camera_id):
+ """Fixes the camera in a pre-defined position, taking away all DOF.
+
+ Args:
+ fixed_camera_id: A non-negative integer, ID of a fixed camera defined in
+ the scene.
+ """
+ if fixed_camera_id < 0:
+ return
+ self._camera.trackbodyid = _NO_BODY_TRACKED_INDEX
+ self._camera.fixedcamid = fixed_camera_id
+ self._camera.type_ = mujoco.mjtCamera.mjCAMERA_FIXED
+
+ def look_at(self, position, distance):
+ """Positions the camera so that it's focused on the specified point."""
+ self._camera.lookat[:] = position
+ self._camera.distance = distance
+
+ def move(self, action, viewport_offset):
+ """Moves the camera around the scene."""
+ # Not checking the validity of arguments on purpose. This method is designed
+ # to be called very often, so in order to avoid the overhead, all arguments
+ # are assumed to be valid.
+ mujoco.mjv_moveCamera(self._model.ptr, action, viewport_offset[0],
+ viewport_offset[1], self._scene.ptr, self._camera.ptr)
+
+ def new_perturbation(self, body_id):
+ """Creates a proxy that allows the specified object to be manipulated."""
+ return Perturbation(body_id, self._model, self._data, self._scene)
+
+ def raycast(self, viewport, screen_pos):
+ """Shoots a ray from the specified viewport position into the scene."""
+ if not self.is_initialized:
+ return -1, None
+ viewport_pos = viewport.screen_to_inverse_viewport(screen_pos)
+ grab_world_pos = np.empty(3, dtype=np.double)
+ selected_geom_id_arr = np.intc([-1])
+ selected_flex_id_arr = np.intc([-1])
+ selected_skin_id_arr = np.intc([-1])
+ selected_body_id = mujoco.mjv_select(
+ self._model.ptr,
+ self._data.ptr,
+ self._options.visualization.ptr,
+ viewport.aspect_ratio,
+ viewport_pos[0],
+ viewport_pos[1],
+ self._scene.ptr,
+ grab_world_pos,
+ selected_geom_id_arr,
+ selected_flex_id_arr,
+ selected_skin_id_arr,
+ )
+ del (
+ selected_geom_id_arr,
+ selected_skin_id_arr,
+ selected_flex_id_arr,
+ ) # Unused.
+ if selected_body_id < 0:
+ selected_body_id = _INVALID_BODY_INDEX
+ grab_world_pos = None
+ return selected_body_id, grab_world_pos
+
+ def render(self, perturbation=None):
+ """Renders the scene from this camera's perspective.
+
+ Args:
+ perturbation: (Optional), instance of Perturbation.
+ Returns:
+ Rendered scene, instance of MjvScene.
+ """
+ perturb_to_render = perturbation.ptr if perturbation else None
+ mujoco.mjv_updateScene(self._model.ptr, self._data.ptr,
+ self._options.visualization.ptr, perturb_to_render,
+ self._camera.ptr, mujoco.mjtCatBit.mjCAT_ALL,
+ self._scene.ptr)
+
+ # Apply callback if defined.
+ if self._scene_callback is not None:
+ self._scene_callback(self._model, self._data, self._scene)
+ return self._scene
+
+ def zoom_to_scene(self):
+ """Zooms in on the entire scene."""
+ self.look_at(self._model.stat.center[:],
+ self._zoom_factor * self._model.stat.extent)
+
+ self.settings = self._settings
+
+ @property
+ def transform(self):
+ """Returns a tuple with camera transform.
+
+ The transform has the form: (3x3 rotation matrix, 3-component position).
+ """
+ pos = np.zeros(3)
+ forward = np.zeros(3)
+ up = np.zeros(3)
+ for i in range(3):
+ forward[i] = self._scene.camera[0].forward[i]
+ up[i] = self._scene.camera[0].up[i]
+ pos[i] = (self._scene.camera[0].pos[i] + self._scene.camera[1].pos[i]) / 2
+ right = np.cross(forward, up)
+ return np.array([right, up, forward]), pos
+
+ @property
+ def settings(self):
+ """Returns internal camera settings."""
+ return self._camera
+
+ @settings.setter
+ def settings(self, value):
+ """Restores the camera settings."""
+ self._camera.type_ = value.type_
+ self._camera.fixedcamid = value.fixedcamid
+ self._camera.trackbodyid = value.trackbodyid
+ self._camera.lookat[:] = value.lookat[:]
+ self._camera.distance = value.distance
+ self._camera.azimuth = value.azimuth
+ self._camera.elevation = value.elevation
+
+ @property
+ def name(self):
+ """Name of the active camera."""
+ if self._camera.type_ == mujoco.mjtCamera.mjCAMERA_TRACKING:
+ body_name = self._model.id2name(self._camera.trackbodyid, 'body')
+ if body_name:
+ return 'Tracking body "%s"' % body_name
+ else:
+ return 'Tracking body id %d' % self._camera.trackbodyid
+ elif self._camera.type_ == mujoco.mjtCamera.mjCAMERA_FIXED:
+ camera_name = self._model.id2name(self._camera.fixedcamid, 'camera')
+ if camera_name:
+ return str(camera_name)
+ else:
+ return str(self._camera.fixedcamid)
+ else:
+ return 'Free'
+
+ @property
+ def mode(self):
+ """Index of the mode the camera is currently in."""
+ return self._camera.type_
+
+ @property
+ def is_initialized(self):
+ """Returns True if camera is properly initialized."""
+ if not self._scene:
+ return False
+ frustum_near = self._scene.camera[0].frustum_near
+ frustum_far = self._scene.camera[0].frustum_far
+ return frustum_near > 0 and frustum_near < frustum_far
+
+
+class Viewport:
+ """Render viewport."""
+
+ def __init__(self, width=1, height=1):
+ """Instance initializer.
+
+ Args:
+ width: Viewport width, in pixels.
+ height: Viewport height, in pixels.
+ """
+ self._screen_size = mujoco.MjrRect(0, 0, width, height)
+
+ def set_size(self, width, height):
+ """Changes the viewport size.
+
+ Args:
+ width: Viewport width, in pixels.
+ height: Viewport height, in pixels.
+ """
+ self._screen_size.width = width
+ self._screen_size.height = height
+
+ def screen_to_viewport(self, screen_coordinates):
+ """Converts screen coordinates to viewport coordinates.
+
+ Args:
+ screen_coordinates: 2-component tuple, with components being integral
+ numbers in range defined by the screen/window resolution.
+ Returns:
+ A 2-component tuple, with components being floating point values in range
+ [0, 1].
+ """
+ x = screen_coordinates[0] / self._screen_size.width
+ y = screen_coordinates[1] / self._screen_size.height
+ return np.array([x, y], np.float32)
+
+ def screen_to_inverse_viewport(self, screen_coordinates):
+ """Converts screen coordinates to viewport coordinates flipped vertically.
+
+ Args:
+ screen_coordinates: 2-component tuple, with components being integral
+ numbers in range defined by the screen/window resolution.
+ Returns:
+ A 2-component tuple, with components being floating point values in range
+ [0, 1]. The height component value will be flipped, with 1 at the top, and
+ 0 at the bottom of the viewport.
+ """
+ x = screen_coordinates[0] / self._screen_size.width
+ y = 1. - (screen_coordinates[1] / self._screen_size.height)
+ return np.array([x, y], np.float32)
+
+ @property
+ def aspect_ratio(self):
+ return self._screen_size.width / self._screen_size.height
+
+ @property
+ def mujoco_rect(self):
+ """Instance of MjrRect with viewport dimensions."""
+ return self._screen_size
+
+ @property
+ def dimensions(self):
+ """Viewport dimensions in form of a 2-component vector."""
+ return np.asarray([self._screen_size.width, self._screen_size.height])
+
+ @property
+ def width(self):
+ """Viewport width."""
+ return self._screen_size.width
+
+ @property
+ def height(self):
+ """Viewport height."""
+ return self._screen_size.height
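Note on the `Viewport` conversions above: `screen_to_viewport` and `screen_to_inverse_viewport` differ only in the vertical axis. The following is a minimal sketch of the same arithmetic in plain NumPy. The `to_viewport` helper and its `flip_y` parameter are hypothetical names introduced for illustration and are not part of `dm_control`; the sketch also assumes the screen origin sits at the top-left, which is what the "1 at the top" wording of the docstring implies.

```python
import numpy as np


def to_viewport(screen_xy, width, height, flip_y=False):
  """Hypothetical helper mirroring Viewport.screen_to_[inverse_]viewport."""
  x = screen_xy[0] / width
  y = screen_xy[1] / height
  if flip_y:
    # Assumes screen y grows downwards, so flipping puts y == 1 at the top.
    y = 1. - y
  return np.array([x, y], np.float32)


# A 100x100 viewport, matching the size used in ViewportTest.
top_left = to_viewport([0, 0], 100, 100)
top_left_flipped = to_viewport([0, 0], 100, 100, flip_y=True)
centre_flipped = to_viewport([50, 50], 100, 100, flip_y=True)
```

The flipped variant is the one passed to `mjv_select` in `SceneCamera.raycast`, according to `test_raycast_maps_coordinates_to_viewport_space`.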
diff --git a/dm_control/viewer/renderer_test.py b/dm_control/viewer/renderer_test.py
new file mode 100644
index 00000000..2fb0ab27
--- /dev/null
+++ b/dm_control/viewer/renderer_test.py
@@ -0,0 +1,549 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests of the renderer module."""
+
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.viewer import renderer
+import mock
+import mujoco
+import numpy as np
+
+
+renderer.mujoco = mock.MagicMock()
+
+_SCREEN_SIZE = mujoco.MjrRect(0, 0, 320, 240)
+
+
+class BaseRendererTest(absltest.TestCase):
+
+ class MockRenderer(renderer.BaseRenderer):
+ pass
+
+ class MockRenderComponent(renderer.Component):
+
+ counter = 0
+
+ def __init__(self):
+ self._call_order = -1
+
+ def render(self, context, viewport):
+ self._call_order = BaseRendererTest.MockRenderComponent.counter
+ BaseRendererTest.MockRenderComponent.counter += 1
+
+ @property
+ def call_order(self):
+ return self._call_order
+
+ def setUp(self):
+ super().setUp()
+ self.renderer = BaseRendererTest.MockRenderer()
+ self.context = mock.MagicMock()
+ self.viewport = mock.MagicMock()
+
+ def test_rendering_components(self):
+ regular_component = BaseRendererTest.MockRenderComponent()
+ screen_capture_component = BaseRendererTest.MockRenderComponent()
+ self.renderer.components += regular_component
+ self.renderer.screen_capture_components += screen_capture_component
+ self.renderer._render_components(self.context, self.viewport)
+ self.assertEqual(0, regular_component.call_order)
+ self.assertEqual(1, screen_capture_component.call_order)
+
+
+@absltest.skip('b/222664582')
+class OffScreenRendererTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.model = mock.MagicMock()
+ self.model.vis.global_.offwidth = _SCREEN_SIZE.width
+ self.model.vis.global_.offheight = _SCREEN_SIZE.height
+
+ self.surface = mock.MagicMock()
+ self.renderer = renderer.OffScreenRenderer(self.model, self.surface)
+ self.renderer._mujoco_context = mock.MagicMock()
+
+ self.viewport = mock.MagicMock()
+ self.scene = mock.MagicMock()
+
+ self.viewport.width = 3
+ self.viewport.height = 3
+ self.viewport.dimensions = np.array([3, 3])
+
+ def test_render_context_initialization(self):
+ self.renderer._mujoco_context = None
+ self.renderer.render(self.viewport, self.scene)
+ self.assertIsNotNone(self.renderer._mujoco_context)
+
+ def test_resizing_pixel_buffer_to_viewport_size(self):
+ self.renderer.render(self.viewport, self.scene)
+ self.assertEqual((self.viewport.width, self.viewport.height, 3),
+ self.renderer._rgb_buffer.shape)
+
+ def test_rendering_components(self):
+ regular_component = mock.MagicMock()
+ screen_capture_components = mock.MagicMock()
+ self.renderer.components += [regular_component]
+ self.renderer.screen_capture_components += [screen_capture_components]
+ self.renderer._render_on_gl_thread(self.viewport, self.scene)
+ regular_component.render.assert_called_once()
+ screen_capture_components.render.assert_called_once()
+
+
+@absltest.skip('b/222664582')
+class PerturbationTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.model = mock.MagicMock()
+ self.data = mock.MagicMock()
+ self.scene = mock.MagicMock()
+ self.valid_pos = np.array([1, 2, 3])
+
+ self.body_id = 0
+ self.data.xpos = [np.array([0, 1, 2])]
+ self.data.xmat = [np.identity(3)]
+
+ self.perturbation = renderer.Perturbation(
+ self.body_id, self.model, self.data, self.scene)
+
+ renderer.mujoco.reset_mock()
+
+ def test_start_params_validation(self):
+ self.perturbation.start_move(None, self.valid_pos)
+ self.assertEqual(0, renderer.mujoco.mjv_initPerturb.call_count)
+ self.assertEqual(enums.mjtMouse.mjMOUSE_NONE, self.perturbation._action)
+
+ self.perturbation.start_move(enums.mjtMouse.mjMOUSE_MOVE_V, None)
+ self.assertEqual(0, renderer.mujoco.mjv_initPerturb.call_count)
+ self.assertEqual(enums.mjtMouse.mjMOUSE_NONE, self.perturbation._action)
+
+ def test_starting_an_operation(self):
+ self.perturbation.start_move(enums.mjtMouse.mjMOUSE_MOVE_V, self.valid_pos)
+ renderer.mujoco.mjv_initPerturb.assert_called_once()
+ self.assertEqual(enums.mjtMouse.mjMOUSE_MOVE_V, self.perturbation._action)
+
+ def test_starting_translation(self):
+ self.perturbation.start_move(enums.mjtMouse.mjMOUSE_MOVE_V, self.valid_pos)
+ self.assertEqual(
+ enums.mjtPertBit.mjPERT_TRANSLATE, self.perturbation._perturb.active)
+
+ def test_starting_rotation(self):
+ self.perturbation.start_move(enums.mjtMouse.mjMOUSE_ROTATE_V,
+ self.valid_pos)
+ self.assertEqual(
+ enums.mjtPertBit.mjPERT_ROTATE, self.perturbation._perturb.active)
+
+ def test_starting_grip_transform(self):
+ self.perturbation.start_move(enums.mjtMouse.mjMOUSE_MOVE_V, self.valid_pos)
+ np.testing.assert_array_equal(
+ [1, 1, 1], self.perturbation._perturb.localpos)
+
+ def test_ticking_operation(self):
+ self.perturbation._action = enums.mjtMouse.mjMOUSE_MOVE_V
+ self.perturbation.tick_move([.1, .2])
+ renderer.mujoco.mjv_movePerturb.assert_called_once()
+ action, dx, dy = renderer.mujoco.mjv_movePerturb.call_args[0][2:5]
+ self.assertEqual(self.perturbation._action, action)
+ self.assertEqual(.1, dx)
+ self.assertEqual(.2, dy)
+
+ def test_ticking_stopped_operation_yields_no_results(self):
+ self.perturbation._action = None
+ self.perturbation.tick_move([.1, .2])
+ self.assertEqual(0, renderer.mujoco.mjv_movePerturb.call_count)
+
+ self.perturbation._action = enums.mjtMouse.mjMOUSE_NONE
+ self.perturbation.tick_move([.1, .2])
+ self.assertEqual(0, renderer.mujoco.mjv_movePerturb.call_count)
+
+ def test_stopping_operation(self):
+ self.perturbation._action = enums.mjtMouse.mjMOUSE_MOVE_V
+ self.perturbation._perturb.active = enums.mjtPertBit.mjPERT_TRANSLATE
+ self.perturbation.end_move()
+ self.assertEqual(enums.mjtMouse.mjMOUSE_NONE, self.perturbation._action)
+ self.assertEqual(0, self.perturbation._perturb.active)
+
+ def test_applying_operation_results_while_not_paused(self):
+ with self.perturbation.apply(False):
+ renderer.mujoco.mjv_applyPerturbPose.assert_called_once()
+ self.assertEqual(0, renderer.mujoco.mjv_applyPerturbPose.call_args[0][3])
+ renderer.mujoco.mjv_applyPerturbForce.assert_called_once()
+
+ def test_applying_operation_results_while_paused(self):
+ with self.perturbation.apply(True):
+ renderer.mujoco.mjv_applyPerturbPose.assert_called_once()
+ self.assertEqual(1, renderer.mujoco.mjv_applyPerturbPose.call_args[0][3])
+ self.assertEqual(0, renderer.mujoco.mjv_applyPerturbForce.call_count)
+
+ def test_clearing_applied_forces_after_applying_operation(self):
+ self.data.xfrc_applied = np.zeros(1)
+ with self.perturbation.apply(True):
+ # At this point the simulation will calculate forces to apply and assign
+ # them to the proper MjData structure field, as we're doing below.
+ self.data.xfrc_applied[self.body_id] = 1
+
+ # While exiting, the context clears that information.
+ self.assertEqual(0, self.data.xfrc_applied[self.body_id])
+
+
+@absltest.skip('b/222664582')
+class RenderSettingsTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.settings = renderer.RenderSettings()
+ self.scene = wrapper.MjvScene()
+
+ def test_applying_settings(self):
+ self.settings._stereo_mode = 5
+ self.settings._render_flags[:] = np.arange(len(self.settings._render_flags))
+ self.settings.apply_settings(self.scene)
+ self.assertEqual(self.settings._stereo_mode, self.scene.stereo)
+ np.testing.assert_array_equal(self.settings._render_flags, self.scene.flags)
+
+ def test_toggle_rendering_flag(self):
+ self.settings._render_flags[0] = 1
+ self.settings.toggle_rendering_flag(0)
+ self.assertEqual(0, self.settings._render_flags[0])
+ self.settings.toggle_rendering_flag(0)
+ self.assertEqual(1, self.settings._render_flags[0])
+
+ def test_toggle_visualization_flag(self):
+ self.settings._visualization_options.flags[0] = 1
+ self.settings.toggle_visualization_flag(0)
+ self.assertEqual(0, self.settings._visualization_options.flags[0])
+ self.settings.toggle_visualization_flag(0)
+ self.assertEqual(1, self.settings._visualization_options.flags[0])
+
+ def test_toggle_geom_group(self):
+ self.settings._visualization_options.geomgroup[0] = 1
+ self.settings.toggle_geom_group(0)
+ self.assertEqual(0, self.settings._visualization_options.geomgroup[0])
+ self.settings.toggle_geom_group(0)
+ self.assertEqual(1, self.settings._visualization_options.geomgroup[0])
+
+ def test_toggle_site_group(self):
+ self.settings._visualization_options.sitegroup[0] = 1
+ self.settings.toggle_site_group(0)
+ self.assertEqual(0, self.settings._visualization_options.sitegroup[0])
+ self.settings.toggle_site_group(0)
+ self.assertEqual(1, self.settings._visualization_options.sitegroup[0])
+
+ def test_toggle_stereo_buffering(self):
+ self.settings.toggle_stereo_buffering()
+ self.assertEqual(enums.mjtStereo.mjSTEREO_QUADBUFFERED,
+ self.settings._stereo_mode)
+ self.settings.toggle_stereo_buffering()
+ self.assertEqual(enums.mjtStereo.mjSTEREO_NONE,
+ self.settings._stereo_mode)
+
+ def test_cycling_forward_through_render_modes(self):
+ self.settings._visualization_options.frame = 0
+ self.settings.select_next_rendering_mode()
+ self.assertEqual(1, self.settings._visualization_options.frame)
+
+ self.settings._visualization_options.frame = enums.mjtFrame.mjNFRAME - 1
+ self.settings.select_next_rendering_mode()
+ self.assertEqual(0, self.settings._visualization_options.frame)
+
+ def test_cycling_backward_through_render_modes(self):
+ self.settings._visualization_options.frame = 0
+ self.settings.select_prev_rendering_mode()
+ self.assertEqual(enums.mjtFrame.mjNFRAME - 1,
+ self.settings._visualization_options.frame)
+
+ self.settings._visualization_options.frame = 1
+ self.settings.select_prev_rendering_mode()
+ self.assertEqual(0, self.settings._visualization_options.frame)
+
+ def test_cycling_forward_through_labeling_modes(self):
+ self.settings._visualization_options.label = 0
+ self.settings.select_next_labeling_mode()
+ self.assertEqual(1, self.settings._visualization_options.label)
+
+ self.settings._visualization_options.label = enums.mjtLabel.mjNLABEL - 1
+ self.settings.select_next_labeling_mode()
+ self.assertEqual(0, self.settings._visualization_options.label)
+
+ def test_cycling_backward_through_labeling_modes(self):
+ self.settings._visualization_options.label = 0
+ self.settings.select_prev_labeling_mode()
+ self.assertEqual(enums.mjtLabel.mjNLABEL - 1,
+ self.settings._visualization_options.label)
+
+ self.settings._visualization_options.label = 1
+ self.settings.select_prev_labeling_mode()
+ self.assertEqual(0, self.settings._visualization_options.label)
+
+
+@absltest.skip('b/222664582')
+class SceneCameraTest(parameterized.TestCase):
+
+ @mock.patch.object(renderer.wrapper.core,
+ '_estimate_max_renderable_geoms',
+ return_value=1000)
+ @mock.patch.object(renderer.wrapper.core.mujoco, 'MjvScene')
+ def setUp(self, mock_make_scene, _):
+ super().setUp()
+ self.model = mock.MagicMock()
+ self.data = mock.MagicMock()
+ self.options = mock.MagicMock()
+ self.camera = renderer.SceneCamera(self.model, self.data, self.options)
+ mock_make_scene.assert_called_once()
+
+ def test_freelook_mode(self):
+ self.camera.set_freelook_mode()
+ self.assertEqual(-1, self.camera._camera.trackbodyid)
+ self.assertEqual(-1, self.camera._camera.fixedcamid)
+ self.assertEqual(enums.mjtCamera.mjCAMERA_FREE, self.camera._camera.type_)
+ self.assertEqual('Free', self.camera.name)
+
+ def test_tracking_mode(self):
+ body_id = 5
+ self.camera.set_tracking_mode(body_id)
+ self.assertEqual(body_id, self.camera._camera.trackbodyid)
+ self.assertEqual(-1, self.camera._camera.fixedcamid)
+ self.assertEqual(enums.mjtCamera.mjCAMERA_TRACKING,
+ self.camera._camera.type_)
+
+ self.model.id2name = mock.MagicMock(return_value='body_name')
+ self.assertEqual('Tracking body "body_name"', self.camera.name)
+
+ def test_fixed_mode(self):
+ camera_id = 5
+ self.camera.set_fixed_mode(camera_id)
+ self.assertEqual(-1, self.camera._camera.trackbodyid)
+ self.assertEqual(camera_id, self.camera._camera.fixedcamid)
+ self.assertEqual(enums.mjtCamera.mjCAMERA_FIXED,
+ self.camera._camera.type_)
+
+ self.model.id2name = mock.MagicMock(return_value='camera_name')
+ self.assertEqual('camera_name', self.camera.name)
+
+ def test_look_at(self):
+ target_pos = [10, 20, 30]
+ distance = 5.
+ self.camera.look_at(target_pos, distance)
+ np.testing.assert_array_equal(target_pos, self.camera._camera.lookat)
+ np.testing.assert_array_equal(distance, self.camera._camera.distance)
+
+ def test_moving_camera(self):
+ action = enums.mjtMouse.mjMOUSE_MOVE_V
+ offset = [0.1, -0.2]
+ with mock.patch(renderer.__name__ + '.mujoco') as mock_mujoco:
+ self.camera.move(action, offset)
+ mock_mujoco.mjv_moveCamera.assert_called_once()
+
+ def test_zoom_to_scene(self):
+ scene_center = np.array([1, 2, 3])
+ scene_extents = np.array([10, 20, 30])
+
+ self.camera.look_at = mock.MagicMock()
+ self.model.stat = mock.MagicMock()
+ self.model.stat.center = scene_center
+ self.model.stat.extent = scene_extents
+
+ self.camera.zoom_to_scene()
+ self.camera.look_at.assert_called_once()
+ np.testing.assert_array_equal(
+ scene_center, self.camera.look_at.call_args[0][0])
+ np.testing.assert_array_equal(
+ scene_extents * 1.5, self.camera.look_at.call_args[0][1])
+
+ def test_camera_transform(self):
+ self.camera._scene.camera[0].up[:] = [0, 1, 0]
+ self.camera._scene.camera[0].forward[:] = [0, 0, 1]
+ self.camera._scene.camera[0].pos[:] = [5, 0, 0]
+ self.camera._scene.camera[1].pos[:] = [10, 0, 0]
+
+ rotation_mtx, position = self.camera.transform
+ np.testing.assert_array_equal([-1, 0, 0], rotation_mtx[0])
+ np.testing.assert_array_equal([0, 1, 0], rotation_mtx[1])
+ np.testing.assert_array_equal([0, 0, 1], rotation_mtx[2])
+ np.testing.assert_array_equal([7.5, 0, 0], position)
+
+ @parameterized.parameters(
+ (0, 0, False),
+ (0, 1, False),
+ (1, 0, False),
+ (2, 1, False),
+ (1, 2, True))
+ def test_is_camera_initialized(self, frustum_near, frustum_far, result):
+ gl_camera = mock.MagicMock()
+ self.camera._scene = mock.MagicMock()
+ self.camera._scene.camera = [gl_camera]
+
+ gl_camera.frustum_near = frustum_near
+ gl_camera.frustum_far = frustum_far
+ self.assertEqual(result, self.camera.is_initialized)
+
+
+@absltest.skip('b/222664582')
+class RaycastsTest(absltest.TestCase):
+
+ @mock.patch.object(renderer.wrapper.core,
+ '_estimate_max_renderable_geoms',
+ return_value=1000)
+ @mock.patch.object(renderer.wrapper.core.mujoco, 'MjvScene')
+ def setUp(self, mock_make_scene, _):
+ super().setUp()
+ self.model = mock.MagicMock()
+ self.data = mock.MagicMock()
+ self.options = mock.MagicMock()
+
+ self.viewport = mock.MagicMock()
+ self.camera = renderer.SceneCamera(self.model, self.data, self.options)
+ mock_make_scene.assert_called_once()
+ self.initialize_camera(True)
+
+ def initialize_camera(self, enable):
+ gl_camera = mock.MagicMock()
+ self.camera._scene = mock.MagicMock()
+ self.camera._scene.camera = [gl_camera]
+ gl_camera.frustum_near = 1 if enable else 0
+ gl_camera.frustum_far = 2 if enable else 0
+
+ def test_raycast_mapping_geom_to_body_id(self):
+ def build_mjv_select(mock_body_id, mock_geom_id, mock_position):
+
+ def mock_select(
+ m,
+ d,
+ vopt,
+ aspectratio,
+ relx,
+ rely,
+ scn,
+ selpnt,
+ geomid,
+ flexid,
+ skinid,
+ ):
+ del m, d, vopt, aspectratio, relx, rely, scn, flexid, skinid # Unused.
+ selpnt[:] = mock_position
+ geomid[:] = mock_geom_id
+ return mock_body_id
+
+ return mock_select
+
+ geom_id = 0
+ body_id = 5
+ world_pos = [1, 2, 3]
+ self.model.geom_bodyid = np.zeros(10)
+ self.model.geom_bodyid[geom_id] = body_id
+ mock_select = build_mjv_select(body_id, geom_id, world_pos)
+
+ with mock.patch(renderer.__name__ + '.mujoco') as mock_mujoco:
+ mock_mujoco.mjv_select = mock.MagicMock(side_effect=mock_select)
+ hit_body_id, hit_world_pos = self.camera.raycast(self.viewport, [0, 0])
+ self.assertEqual(hit_body_id, body_id)
+ np.testing.assert_array_equal(hit_world_pos, world_pos)
+
+ def test_raycast_hitting_empty_space(self):
+
+ def mock_select(
+ m, d, vopt, aspectratio, relx, rely, scn, selpnt, geomid, flexid, skinid
+ ):
+ del (m, d, vopt, aspectratio, relx, rely, scn, selpnt, geomid, flexid,
+ skinid) # Unused.
+ mock_body_id = -1 # Nothing selected.
+ return mock_body_id
+
+ with mock.patch(renderer.__name__ + '.mujoco') as mock_mujoco:
+ mock_mujoco.mjv_select = mock.MagicMock(side_effect=mock_select)
+ hit_body_id, hit_world_pos = self.camera.raycast(self.viewport, [0, 0])
+ self.assertEqual(-1, hit_body_id)
+ self.assertIsNone(hit_world_pos)
+
+ def test_raycast_maps_coordinates_to_viewport_space(self):
+ def build_mjv_select(expected_aspect_ratio, expected_viewport_pos):
+
+ def mock_select(
+ m,
+ d,
+ vopt,
+ aspectratio,
+ relx,
+ rely,
+ scn,
+ selpnt,
+ geomid,
+ flexid,
+ skinid,
+ ):
+ del m, d, vopt, scn, selpnt, geomid, flexid, skinid # Unused.
+ self.assertEqual(expected_aspect_ratio, aspectratio)
+ np.testing.assert_array_equal(expected_viewport_pos, [relx, rely])
+ mock_body_id = 0
+ return mock_body_id
+
+ return mock_select
+
+ viewport_pos = [.5, .5]
+ self.viewport.screen_to_inverse_viewport.return_value = viewport_pos
+ mock_select = build_mjv_select(self.viewport.aspect_ratio, viewport_pos)
+
+ with mock.patch(renderer.__name__ + '.mujoco') as mock_mujoco:
+ mock_mujoco.mjv_select = mock.MagicMock(side_effect=mock_select)
+ self.camera.raycast(self.viewport, [50, 25])
+
+ def test_raycasts_disabled_when_camera_is_not_initialized(self):
+ self.initialize_camera(False)
+ hit_body_id, hit_world_pos = self.camera.raycast(self.viewport, [0, 0])
+ self.assertEqual(-1, hit_body_id)
+ self.assertIsNone(hit_world_pos)
+
+
+class ViewportTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.viewport = renderer.Viewport()
+ self.viewport.set_size(100, 100)
+
+ @parameterized.parameters(
+ ([0, 0], [0., 0.]),
+ ([100, 0], [1., 0.]),
+ ([0, 100], [0., 1.]),
+ ([50, 50], [.5, .5]))
+ def test_screen_to_viewport(self, screen_coords, viewport_coords):
+ np.testing.assert_array_equal(
+ viewport_coords, self.viewport.screen_to_viewport(screen_coords))
+
+ @parameterized.parameters(
+ ([0, 0], [0., 1.]),
+ ([100, 0], [1., 1.]),
+ ([0, 100], [0., 0.]),
+ ([50, 50], [.5, .5]))
+ def test_screen_to_inverse_viewport(self, screen_coords, viewport_coords):
+ np.testing.assert_array_equal(
+ viewport_coords,
+ self.viewport.screen_to_inverse_viewport(screen_coords))
+
+ @parameterized.parameters(
+ ([10, 10], 1.),
+ ([30, 40], 3./4.))
+ def test_aspect_ratio(self, screen_size, aspect_ratio):
+ self.viewport.set_size(screen_size[0], screen_size[1])
+ self.assertEqual(aspect_ratio, self.viewport.aspect_ratio)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/runtime.py b/dm_control/viewer/runtime.py
new file mode 100644
index 00000000..e7477dbd
--- /dev/null
+++ b/dm_control/viewer/runtime.py
@@ -0,0 +1,257 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Environment's execution runtime."""
+
+import collections
+import copy
+import enum
+
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.viewer import util
+import mujoco
+import numpy as np
+
+mjlib = mjbindings.mjlib
+
+
+# Pause interval between simulation steps.
+_SIMULATION_STEP_INTERVAL = 0.001
+
+# The longest allowed simulation time step, in seconds.
+_DEFAULT_MAX_SIM_STEP = 1./5.
+
+
+def _get_default_action(action_spec):
+ """Generates an action to apply to the environment if there is no agent.
+
+ * For action dimensions bounded on both sides this will be the midpoint.
+ * For dimensions unbounded only below or only above this will be the
+ maximum or the minimum respectively.
+ * For dimensions unbounded on both sides this will be zero.
+
+ Args:
+ action_spec: An instance of `BoundedArraySpec`, or a list, tuple or
+ mutable mapping containing these.
+
+ Returns:
+ A numpy array if `action_spec` is a single `BoundedArraySpec`, a tuple of
+ arrays for a list or tuple, or a copy of the mapping with array values.
+ """
+ if isinstance(action_spec, (list, tuple)):
+ return tuple(_get_default_action(spec) for spec in action_spec)
+ elif isinstance(action_spec, collections.abc.MutableMapping):
+ # Clones the Mapping, preserving type and key order.
+ result = copy.copy(action_spec)
+
+ for key, value in action_spec.items():
+ result[key] = _get_default_action(value)
+
+ return result
+
+ minimum = np.broadcast_to(action_spec.minimum, action_spec.shape)
+ maximum = np.broadcast_to(action_spec.maximum, action_spec.shape)
+ left_bounded = np.isfinite(minimum)
+ right_bounded = np.isfinite(maximum)
+ action = np.select(
+ condlist=[left_bounded & right_bounded, left_bounded, right_bounded],
+ choicelist=[0.5 * (minimum + maximum), minimum, maximum],
+ default=0.)
+ action = action.astype(action_spec.dtype, copy=False)
+ action.flags.writeable = False
+ return action
+
+
+class State(enum.Enum):
+ """State of the Runtime class."""
+ START = 0
+ RUNNING = 1
+ STOP = 2
+ STOPPED = 3
+ RESTARTING = 4
+
+
+class Runtime:
+ """Base Runtime class.
+
+ Attributes:
+ simulation_time_budget: Float value, how much time can be spent on physics
+ simulation every frame, in seconds.
+ on_episode_begin: An observable subject, an instance of util.QuietSet.
+ It contains argumentless callables, invoked when a new episode begins.
+ on_error: An observable subject, an instance of util.QuietSet. It contains
+ single-argument callables, invoked when the environment or the agent
+ raises an error.
+ on_physics_changed: An observable subject, an instance of util.QuietSet.
+ During episode restarts, the underlying physics instance may change. If
+ you are interested in learning about those changes, attach a listener
+ using the += operator. The listener should be a callable with no required
+ arguments.
+ """
+
+ def __init__(self, environment, policy=None):
+ """Instance initializer.
+
+ Args:
+ environment: An instance of dm_control.rl.control.Environment.
+ policy: Either a callable that accepts a `TimeStep` and returns a numpy
+ array of actions conforming to `environment.action_spec()`, or None, in
+ which case a default action will be generated for each environment step.
+ """
+ self.on_error = util.QuietSet()
+ self.on_episode_begin = util.QuietSet()
+ self.simulation_time_budget = _DEFAULT_MAX_SIM_STEP
+
+ self._state = State.START
+ self._simulation_timer = util.Timer()
+ self._tracked_simulation_time = 0.0
+ self._error_logger = util.ErrorLogger(self.on_error)
+
+ self._env = environment
+ self._policy = policy
+ self._default_action = _get_default_action(environment.action_spec())
+ self._time_step = None
+ self._last_action = None
+ self.on_physics_changed = util.QuietSet()
+
+ def tick(self, time_elapsed, paused):
+ """Advances the simulation by one frame.
+
+ Args:
+ time_elapsed: Time elapsed since the last time this method was called.
+ paused: A boolean flag telling if the simulation is paused.
+ Returns:
+ None. Read the `state` property to see whether the episode has finished.
+ """
+ with self._simulation_timer.measure_time():
+ if self._state == State.RESTARTING:
+ self._state = State.START
+ if self._state == State.START:
+ if self._start():
+ self._broadcast_episode_start()
+ self._tracked_simulation_time = self.get_time()
+ self._state = State.RUNNING
+ else:
+ self._state = State.STOPPED
+ if self._state == State.RUNNING:
+ finished = self._step_simulation(time_elapsed, paused)
+ if finished:
+ self._state = State.STOP
+ if self._state == State.STOP:
+ self._state = State.STOPPED
+
+ def _step_simulation(self, time_elapsed, paused):
+ """Runs one frame of simulation; returns True if the episode finished."""
+ finished = False
+ if paused:
+ self._step_paused()
+ else:
+ step_duration = min(time_elapsed, self.simulation_time_budget)
+ actual_simulation_time = self.get_time()
+ if self._tracked_simulation_time >= actual_simulation_time:
+ end_time = actual_simulation_time + step_duration
+ while not finished and self.get_time() < end_time:
+ finished = self._step()
+ self._tracked_simulation_time += step_duration
+ return finished
+
+ def single_step(self):
+ """Performs a single step of simulation."""
+ if self._state == State.RUNNING:
+ finished = self._step()
+ self._state = State.STOP if finished else State.RUNNING
+
+ def stop(self):
+ """Stops the runtime."""
+ self._state = State.STOPPED
+
+ def restart(self):
+ """Restarts the episode, resetting environment, model, and data."""
+ if self._state != State.STOPPED:
+ self._state = State.RESTARTING
+ else:
+ self._state = State.START
+
+ def get_time(self):
+ """Elapsed simulation time."""
+ return self._env.physics.data.time
+
+ @property
+ def state(self):
+ """Returns the current state of the state machine.
+
+ Returned states are values of runtime.State enum.
+ """
+ return self._state
+
+ @property
+ def simulation_time(self):
+ """Returns the amount of time spent running the simulation."""
+ return self._simulation_timer.measured_time
+
+ @property
+ def last_action(self):
+ """Action passed to the environment on the last step."""
+ return self._last_action
+
+ def _broadcast_episode_start(self):
+ for listener in self.on_episode_begin:
+ listener()
+
+ def _start(self):
+ """Starts a new simulation episode.
+
+ Starting a new episode may be associated with changing the physics instance.
+ The method tracks that and notifies observers through the
+ `on_physics_changed` subject.
+
+ Returns:
+ True if the operation was successful, False otherwise.
+ """
+ # NB: we check the identity of the data pointer rather than the physics
+ # instance itself, since this allows us to detect when the physics has been
+ # "reloaded" using one of the `reload_from_*` methods.
+ old_data_ptr = self._env.physics.data.ptr
+
+ with self._error_logger:
+ self._time_step = self._env.reset()
+
+ if self._env.physics.data.ptr is not old_data_ptr:
+ for listener in self.on_physics_changed:
+ listener()
+ return not self._error_logger.errors_found
+
+ def _step_paused(self):
+ mujoco.mj_forward(self._env.physics.model.ptr, self._env.physics.data.ptr)
+
+ def _step(self):
+ """Generates an action and applies it to the environment.
+
+ If a `policy` was provided, this will be invoked to generate an action to
+ feed to the environment, otherwise a default action will be generated.
+
+ Returns:
+ A boolean value, True if the environment signaled the episode end, False
+ if the episode is still running.
+ """
+ finished = True
+ with self._error_logger:
+ if self._policy:
+ action = self._policy(self._time_step)
+ else:
+ action = self._default_action
+ self._time_step = self._env.step(action)
+ self._last_action = action
+ finished = self._time_step.last()
+ return finished or self._error_logger.errors_found
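Note on `_get_default_action` in `runtime.py` above: the per-dimension rule (midpoint, minimum, maximum, or zero) depends on the order of the conditions handed to `np.select`. The sketch below reproduces that selection with hypothetical bounds, one dimension per case. The `np.errstate` guard is an addition of this sketch and is not in the patch; it only silences the warning NumPy emits when computing `-inf + inf` for the fully unbounded dimension.

```python
import numpy as np

# Hypothetical bounds for a 4-dimensional action, one case per dimension:
# both sides bounded, unbounded below, unbounded above, fully unbounded.
minimum = np.array([-1., -np.inf, 3., -np.inf])
maximum = np.array([3., 2., np.inf, np.inf])

left_bounded = np.isfinite(minimum)
right_bounded = np.isfinite(maximum)

# np.select receives every choice array already evaluated, so the midpoint of
# an unbounded dimension is inf or nan. Those entries are never selected.
with np.errstate(invalid='ignore'):
  midpoint = 0.5 * (minimum + maximum)

# The first matching condition wins, so "bounded on both sides" comes first.
action = np.select(
    condlist=[left_bounded & right_bounded, left_bounded, right_bounded],
    choicelist=[midpoint, minimum, maximum],
    default=0.)
```

With these bounds the four dimensions resolve to the midpoint, the maximum, the minimum, and the default of zero respectively.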
diff --git a/dm_control/viewer/runtime_test.py b/dm_control/viewer/runtime_test.py
new file mode 100644
index 00000000..3d4fd46d
--- /dev/null
+++ b/dm_control/viewer/runtime_test.py
@@ -0,0 +1,399 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Runtime tests."""
+
+import collections
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.viewer import runtime
+import dm_env
+from dm_env import specs
+import mock
+import numpy as np
+
+
+class RuntimeStateMachineTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ env = mock.MagicMock()
+ env.action_spec.return_value = specs.BoundedArray((1,), np.float64, -1, 1)
+ self.runtime = runtime.Runtime(env, mock.MagicMock())
+ self.runtime._start = mock.MagicMock()
+ self.runtime.get_time = mock.MagicMock()
+ self.runtime.get_time.return_value = 0
+ self.runtime._step_simulation = mock.MagicMock(return_value=False)
+
+ def test_initial_state(self):
+ self.assertEqual(self.runtime._state, runtime.State.START)
+
+ def test_successful_starting(self):
+ self.runtime._start.return_value = True
+ self.runtime._state = runtime.State.START
+ self.runtime.tick(0, False)
+ self.assertEqual(self.runtime._state, runtime.State.RUNNING)
+ self.runtime._start.assert_called_once()
+ self.runtime._step_simulation.assert_called_once()
+
+ def test_failure_during_start(self):
+ self.runtime._start.return_value = False
+ self.runtime._state = runtime.State.START
+ self.runtime.tick(0, False)
+ self.assertEqual(self.runtime._state, runtime.State.STOPPED)
+ self.runtime._start.assert_called_once()
+ self.runtime._step_simulation.assert_not_called()
+
+ def test_restarting(self):
+ self.runtime._state = runtime.State.RUNNING
+ self.runtime.restart()
+ self.runtime.tick(0, False)
+ self.assertEqual(self.runtime._state, runtime.State.RUNNING)
+ self.runtime._start.assert_called_once()
+ self.runtime._step_simulation.assert_called_once()
+
+ def test_running(self):
+ self.runtime._state = runtime.State.RUNNING
+ self.runtime.tick(0, False)
+ self.assertEqual(self.runtime._state, runtime.State.RUNNING)
+ self.runtime._step_simulation.assert_called_once()
+
+ def test_ending_a_running_episode(self):
+ self.runtime._state = runtime.State.RUNNING
+ self.runtime._step_simulation.return_value = True
+ self.runtime.tick(0, False)
+ self.assertEqual(self.runtime._state, runtime.State.STOPPED)
+ self.runtime._step_simulation.assert_called_once()
+
+ def test_calling_stop_has_immediate_effect_on_state(self):
+ self.runtime.stop()
+ self.assertEqual(self.runtime._state, runtime.State.STOPPED)
+
+ @parameterized.parameters(runtime.State.RUNNING,
+ runtime.State.RESTARTING,)
+ def test_states_affected_by_stop(self, state):
+ self.runtime._state = state
+ self.runtime.stop()
+ self.assertEqual(self.runtime._state, runtime.State.STOPPED)
+
+ def test_notifying_listeners_about_successful_start(self):
+ callback = mock.MagicMock()
+ self.runtime.on_episode_begin += [callback]
+ self.runtime._start.return_value = True
+ self.runtime._state = runtime.State.START
+ self.runtime.tick(0, False)
+ callback.assert_called_once()
+
+ def test_listeners_not_notified_when_start_fails(self):
+ callback = mock.MagicMock()
+ self.runtime.on_episode_begin += [callback]
+ self.runtime._start.return_value = False
+ self.runtime._state = runtime.State.START
+ self.runtime.tick(0, False)
+ callback.assert_not_called()
+
+
+class RuntimeSingleStepTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ env = mock.MagicMock(spec=dm_env.Environment)
+ env.action_spec.return_value = specs.BoundedArray((1,), np.float64, -1, 1)
+ self.runtime = runtime.Runtime(env, mock.MagicMock())
+ self.runtime._step = mock.MagicMock()
+ self.runtime._step.return_value = False
+
+ def test_when_running(self):
+ self.runtime._state = runtime.State.RUNNING
+ self.runtime.single_step()
+ self.assertEqual(self.runtime._state, runtime.State.RUNNING)
+ self.runtime._step.assert_called_once()
+
+ def test_ending_episode(self):
+ self.runtime._state = runtime.State.RUNNING
+ self.runtime._step.return_value = True
+ self.runtime.single_step()
+ self.assertEqual(self.runtime._state, runtime.State.STOP)
+ self.runtime._step.assert_called_once()
+
+ @parameterized.parameters(runtime.State.START,
+ runtime.State.STOP,
+ runtime.State.STOPPED,
+ runtime.State.RESTARTING)
+ def test_runs_only_in_running_state(self, state):
+ self.runtime._state = state
+ self.runtime.single_step()
+ self.assertEqual(self.runtime._state, state)
+ self.assertEqual(0, self.runtime._step.call_count)
+
+
+class RuntimeTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ env = mock.MagicMock(spec=dm_env.Environment)
+ env.action_spec.return_value = specs.BoundedArray((1,), np.float64, -1, 1)
+ self.runtime = runtime.Runtime(env, mock.MagicMock())
+ self.runtime._step_paused = mock.MagicMock()
+ self.runtime._step = mock.MagicMock(return_value=True)
+ self.runtime.get_time = mock.MagicMock(return_value=0)
+
+ self.time_step = 1e-2
+
+ def set_loop(self, num_iterations, finish_after=0):
+ finish_after = finish_after or num_iterations + 1
+ self.delta_time = self.time_step / float(num_iterations)
+ self.time = 0
+ self.iteration = 0
+    def fake_get_time():
+      return self.time
+    def fake_step():
+      self.time += self.delta_time
+      self.iteration += 1
+      return self.iteration >= finish_after
+
+    self.runtime._step = mock.MagicMock(side_effect=fake_step)
+    self.runtime.get_time = mock.MagicMock(side_effect=fake_get_time)
+
+ def test_num_step_calls(self):
+ expected_call_count = 5
+ self.set_loop(num_iterations=expected_call_count)
+ finished = self.runtime._step_simulation(self.time_step, False)
+ self.assertFalse(finished)
+ self.assertEqual(expected_call_count, self.runtime._step.call_count)
+
+ def test_finishing_if_episode_ends(self):
+ num_iterations = 5
+ finish_after = 2
+ self.set_loop(num_iterations=num_iterations, finish_after=finish_after)
+ finished = self.runtime._step_simulation(self.time_step, False)
+ self.assertTrue(finished)
+ self.assertEqual(finish_after, self.runtime._step.call_count)
+
+ def test_stepping_paused(self):
+ self.runtime._step_simulation(0, True)
+ self.runtime._step_paused.assert_called_once()
+ self.assertEqual(0, self.runtime._step.call_count)
+
+ def test_physics_step_takes_less_time_than_tick(self):
+ self.physics_time_step = runtime._DEFAULT_MAX_SIM_STEP * 0.5
+ self.physics_time = 0.0
+ def mock_get_time():
+ return self.physics_time
+ def mock_step():
+ self.physics_time += self.physics_time_step
+ self.runtime._step = mock.MagicMock(side_effect=mock_step)
+ self.runtime.get_time = mock.MagicMock(side_effect=mock_get_time)
+ self.runtime._step_simulation(
+ time_elapsed=runtime._DEFAULT_MAX_SIM_STEP, paused=False)
+ self.assertEqual(2, self.runtime._step.call_count)
+
+ def test_physics_step_takes_more_time_than_tick(self):
+ self.physics_time_step = runtime._DEFAULT_MAX_SIM_STEP * 2
+ self.physics_time = 0.0
+ def mock_get_time():
+ return self.physics_time
+ def mock_step():
+ self.physics_time += self.physics_time_step
+ self.runtime._step = mock.MagicMock(side_effect=mock_step)
+ self.runtime.get_time = mock.MagicMock(side_effect=mock_get_time)
+
+ # Simulates after the first frame
+ self.runtime._step_simulation(
+ time_elapsed=runtime._DEFAULT_MAX_SIM_STEP, paused=False)
+ self.assertEqual(1, self.runtime._step.call_count)
+ self.runtime._step.reset_mock()
+
+ # Then pauses for one frame to let the internal timer catch up with the
+ # simulation timer.
+ self.runtime._step_simulation(
+ time_elapsed=runtime._DEFAULT_MAX_SIM_STEP, paused=False)
+ self.assertEqual(0, self.runtime._step.call_count)
+
+ # Resumes simulation on the subsequent frame.
+ self.runtime._step_simulation(
+ time_elapsed=runtime._DEFAULT_MAX_SIM_STEP, paused=False)
+ self.assertEqual(1, self.runtime._step.call_count)
+
+ def test_updating_tracked_time_during_start(self):
+ invalid_time = 20
+ self.runtime.get_time = mock.MagicMock(return_value=invalid_time)
+
+ valid_time = 2
+ def mock_start():
+ self.runtime.get_time = mock.MagicMock(return_value=valid_time)
+ return True
+
+ self.runtime._start = mock.MagicMock(side_effect=mock_start)
+ self.runtime._step_simulation = mock.MagicMock()
+
+ self.runtime.tick(time_elapsed=runtime._DEFAULT_MAX_SIM_STEP, paused=False)
+ self.assertEqual(valid_time, self.runtime._tracked_simulation_time)
+
+ def test_error_logger_forward_errors_to_listeners(self):
+ callback = mock.MagicMock()
+ self.runtime.on_error += [callback]
+ with self.runtime._error_logger:
+ raise Exception('error message')
+ callback.assert_called_once()
+
+
+class EnvironmentRuntimeTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.observation = mock.MagicMock()
+ self.env = mock.MagicMock(spec=dm_env.Environment)
+ self.env.physics = mock.MagicMock()
+ self.env.step = mock.MagicMock()
+ self.env.action_spec.return_value = specs.BoundedArray(
+ (1,), np.float64, -1, 1)
+ self.policy = mock.MagicMock()
+ self.actions = mock.MagicMock()
+ self.runtime = runtime.Runtime(self.env, self.policy)
+
+ def test_start(self):
+ with mock.patch(runtime.__name__ + '.mujoco'):
+ result = self.runtime._start()
+ self.assertTrue(result)
+ self.env.reset.assert_called_once()
+ self.policy.assert_not_called()
+
+ def test_step_with_policy(self):
+ time_step = mock.Mock(spec=dm_env.TimeStep)
+ self.runtime._time_step = time_step
+ self.runtime._step()
+ self.policy.assert_called_once_with(time_step)
+ self.env.step.assert_called_once_with(self.policy.return_value)
+
+ def test_step_without_policy(self):
+ with mock.patch(
+ runtime.__name__ + '._get_default_action') as mock_get_default_action:
+ this_runtime = runtime.Runtime(environment=self.env, policy=None)
+ this_runtime._step()
+ self.env.step.assert_called_once_with(mock_get_default_action.return_value)
+
+ def test_stepping_paused(self):
+ with mock.patch(runtime.__name__ + '.mujoco') as mujoco:
+ self.runtime._step_paused()
+ mujoco.mj_forward.assert_called_once()
+
+ def test_get_time(self):
+ expected_time = 20
+ self.env.physics = mock.MagicMock()
+ self.env.physics.data = mock.MagicMock()
+ self.env.physics.data.time = expected_time
+ self.assertEqual(expected_time, self.runtime.get_time())
+
+ def test_tracking_physics_instance_changes(self):
+ callback = mock.MagicMock()
+ self.runtime.on_physics_changed += [callback]
+
+ def begin_episode_and_reload_physics():
+ self.env.physics.data.ptr = mock.MagicMock()
+ self.env.reset.side_effect = begin_episode_and_reload_physics
+
+ self.runtime._start()
+ callback.assert_called_once_with()
+
+ def test_tracking_physics_instance_that_doesnt_change(self):
+ callback = mock.MagicMock()
+ self.runtime.on_physics_changed += [callback]
+
+ self.runtime._start()
+ callback.assert_not_called()
+
+ def test_exception_thrown_during_start(self):
+ def raise_exception(*unused_args, **unused_kwargs):
+ raise Exception('test error message')
+ self.runtime._env.reset.side_effect = raise_exception
+ result = self.runtime._start()
+ self.assertFalse(result)
+
+ def test_exception_thrown_during_step(self):
+ def raise_exception(*unused_args, **unused_kwargs):
+ raise Exception('test error message')
+ self.runtime._env.step.side_effect = raise_exception
+ finished = self.runtime._step()
+ self.assertTrue(finished)
+
+
+class DefaultActionFromSpecTest(parameterized.TestCase):
+
+ def assertNestedArraysEqual(self, expected, actual):
+ """Asserts that two potentially nested structures of arrays are equal."""
+ self.assertIs(type(actual), type(expected))
+ if isinstance(expected, (list, tuple)):
+ self.assertIsInstance(actual, (list, tuple))
+ self.assertLen(actual, len(expected))
+ for expected_item, actual_item in zip(expected, actual):
+ self.assertNestedArraysEqual(expected_item, actual_item)
+ elif isinstance(expected, collections.abc.MutableMapping):
+ keys_type = list if isinstance(expected, collections.OrderedDict) else set
+ self.assertEqual(keys_type(actual.keys()), keys_type(expected.keys()))
+ for key, expected_value in expected.items():
+        self.assertNestedArraysEqual(expected_value, actual[key])
+ else:
+ np.testing.assert_array_equal(expected, actual)
+
+ _SHAPE = (2,)
+ _DTYPE = np.float64
+ _ACTION = np.zeros(_SHAPE)
+ _ACTION_SPEC = specs.BoundedArray(_SHAPE, np.float64, -1, 1)
+
+ @parameterized.named_parameters(
+ ('single_array', _ACTION_SPEC, _ACTION),
+ ('tuple', (_ACTION_SPEC, _ACTION_SPEC), (_ACTION, _ACTION)),
+ ('list', [_ACTION_SPEC, _ACTION_SPEC], (_ACTION, _ACTION)),
+ ('dict',
+ {'a': _ACTION_SPEC, 'b': _ACTION_SPEC},
+ {'a': _ACTION, 'b': _ACTION}),
+ ('OrderedDict',
+ collections.OrderedDict([('a', _ACTION_SPEC), ('b', _ACTION_SPEC)]),
+ collections.OrderedDict([('a', _ACTION), ('b', _ACTION)])),
+ )
+ def test_action_structure(self, action_spec, expected_action):
+ self.assertNestedArraysEqual(expected_action,
+ runtime._get_default_action(action_spec))
+
+ def test_ordered_dict_action_structure_with_bad_ordering(self):
+ reversed_spec = collections.OrderedDict([('a', self._ACTION_SPEC),
+ ('b', self._ACTION_SPEC)])
+ expected_action = collections.OrderedDict([('b', self._ACTION),
+ ('a', self._ACTION)])
+ with self.assertRaisesRegex(
+ AssertionError, r"Lists differ: \['a', 'b'\] != \['b', 'a'\]"):
+ self.assertNestedArraysEqual(expected_action,
+ runtime._get_default_action(reversed_spec))
+
+ @parameterized.named_parameters(
+ ('closed',
+ specs.BoundedArray(_SHAPE, _DTYPE, minimum=1., maximum=2.),
+ np.full(_SHAPE, fill_value=1.5, dtype=_DTYPE)),
+ ('left_open',
+ specs.BoundedArray(_SHAPE, _DTYPE, minimum=-np.inf, maximum=2.),
+ np.full(_SHAPE, fill_value=2., dtype=_DTYPE)),
+ ('right_open',
+ specs.BoundedArray(_SHAPE, _DTYPE, minimum=1., maximum=np.inf),
+ np.full(_SHAPE, fill_value=1., dtype=_DTYPE)),
+ ('unbounded',
+ specs.BoundedArray(_SHAPE, _DTYPE, minimum=-np.inf, maximum=np.inf),
+ np.full(_SHAPE, fill_value=0., dtype=_DTYPE)))
+ def test_action_spec_interval(self, action_spec, expected_action):
+ self.assertNestedArraysEqual(expected_action,
+ runtime._get_default_action(action_spec))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/user_input.py b/dm_control/viewer/user_input.py
new file mode 100644
index 00000000..33138d88
--- /dev/null
+++ b/dm_control/viewer/user_input.py
@@ -0,0 +1,310 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utilities for handling keyboard events."""
+
+import collections
+
+
+# Mapped input values, so that we don't have to reference glfw everywhere.
+RELEASE = 0
+PRESS = 1
+REPEAT = 2
+
+KEY_UNKNOWN = -1
+KEY_SPACE = 32
+KEY_APOSTROPHE = 39
+KEY_COMMA = 44
+KEY_MINUS = 45
+KEY_PERIOD = 46
+KEY_SLASH = 47
+KEY_0 = 48
+KEY_1 = 49
+KEY_2 = 50
+KEY_3 = 51
+KEY_4 = 52
+KEY_5 = 53
+KEY_6 = 54
+KEY_7 = 55
+KEY_8 = 56
+KEY_9 = 57
+KEY_SEMICOLON = 59
+KEY_EQUAL = 61
+KEY_A = 65
+KEY_B = 66
+KEY_C = 67
+KEY_D = 68
+KEY_E = 69
+KEY_F = 70
+KEY_G = 71
+KEY_H = 72
+KEY_I = 73
+KEY_J = 74
+KEY_K = 75
+KEY_L = 76
+KEY_M = 77
+KEY_N = 78
+KEY_O = 79
+KEY_P = 80
+KEY_Q = 81
+KEY_R = 82
+KEY_S = 83
+KEY_T = 84
+KEY_U = 85
+KEY_V = 86
+KEY_W = 87
+KEY_X = 88
+KEY_Y = 89
+KEY_Z = 90
+KEY_LEFT_BRACKET = 91
+KEY_BACKSLASH = 92
+KEY_RIGHT_BRACKET = 93
+KEY_GRAVE_ACCENT = 96
+KEY_ESCAPE = 256
+KEY_ENTER = 257
+KEY_TAB = 258
+KEY_BACKSPACE = 259
+KEY_INSERT = 260
+KEY_DELETE = 261
+KEY_RIGHT = 262
+KEY_LEFT = 263
+KEY_DOWN = 264
+KEY_UP = 265
+KEY_PAGE_UP = 266
+KEY_PAGE_DOWN = 267
+KEY_HOME = 268
+KEY_END = 269
+KEY_CAPS_LOCK = 280
+KEY_SCROLL_LOCK = 281
+KEY_NUM_LOCK = 282
+KEY_PRINT_SCREEN = 283
+KEY_PAUSE = 284
+KEY_F1 = 290
+KEY_F2 = 291
+KEY_F3 = 292
+KEY_F4 = 293
+KEY_F5 = 294
+KEY_F6 = 295
+KEY_F7 = 296
+KEY_F8 = 297
+KEY_F9 = 298
+KEY_F10 = 299
+KEY_F11 = 300
+KEY_F12 = 301
+KEY_KP_0 = 320
+KEY_KP_1 = 321
+KEY_KP_2 = 322
+KEY_KP_3 = 323
+KEY_KP_4 = 324
+KEY_KP_5 = 325
+KEY_KP_6 = 326
+KEY_KP_7 = 327
+KEY_KP_8 = 328
+KEY_KP_9 = 329
+KEY_KP_DECIMAL = 330
+KEY_KP_DIVIDE = 331
+KEY_KP_MULTIPLY = 332
+KEY_KP_SUBTRACT = 333
+KEY_KP_ADD = 334
+KEY_KP_ENTER = 335
+KEY_KP_EQUAL = 336
+KEY_LEFT_SHIFT = 340
+KEY_LEFT_CONTROL = 341
+KEY_LEFT_ALT = 342
+KEY_LEFT_SUPER = 343
+KEY_RIGHT_SHIFT = 344
+KEY_RIGHT_CONTROL = 345
+KEY_RIGHT_ALT = 346
+KEY_RIGHT_SUPER = 347
+
+MOD_NONE = 0
+MOD_SHIFT = 0x0001
+MOD_CONTROL = 0x0002
+MOD_ALT = 0x0004
+MOD_SUPER = 0x0008
+MOD_SHIFT_CONTROL = MOD_SHIFT | MOD_CONTROL
+
+MOUSE_BUTTON_LEFT = 0
+MOUSE_BUTTON_RIGHT = 1
+MOUSE_BUTTON_MIDDLE = 2
+
+_NO_EXCLUSIVE_KEY = (None, lambda _: None)
+_NO_CALLBACK = (None, None)
+
+
+class Exclusive(collections.namedtuple('Exclusive', 'combination')):
+ """Defines an exclusive action.
+
+  Exclusive actions can be invoked in response to single key clicks only. The
+  callback is called twice: first when the key combination is pressed, with
+  True as its argument, and again when the key is released (the modifiers
+  don't have to be present then), with False as its argument. While an
+  exclusive action is active, other key bindings are ignored.
+
+  Attributes:
+    combination: An integer key code, or a tuple in format
+      (keycode, modifier).
+ """
+ pass
+
+
+class DoubleClick(collections.namedtuple('DoubleClick', 'combination')):
+ """Defines a mouse double click action.
+
+  The action triggers when the mouse button specified in the combination is
+  double-clicked.
+
+  Attributes:
+    combination: A mouse button code, or a tuple in format
+      (button, modifier). Only mouse button codes are valid here, since
+      double clicks are reported by the mouse.
+ """
+ pass
+
+
+class Range(collections.namedtuple('Range', 'collection')):
+ """Binds a number of key combinations to a callback.
+
+  When triggered, the index of the triggering key combination will be passed
+  as an argument to the callback given to `InputMap.bind`, so that callback
+  must accept a single integer argument.
+
+  Attributes:
+    collection: A collection of combinations. Combinations may either be raw
+      key codes, tuples in format (keycode, modifier), or one of the Exclusive
+      or DoubleClick instances. The position of a combination in the
+      collection is the index passed to the callback.
+ """
+ pass
+
+
+class InputMap:
+ """Provides ability to alias key combinations and map them to actions."""
+
+ def __init__(self, mouse, keyboard):
+ """Instance initializer.
+
+ Args:
+ mouse: GlfwMouse instance.
+ keyboard: GlfwKeyboard instance.
+ """
+ self._keyboard = keyboard
+ self._mouse = mouse
+
+ self._keyboard.on_key += self._handle_key
+ self._mouse.on_click += self._handle_key
+ self._mouse.on_double_click += self._handle_double_click
+ self._mouse.on_move += self._handle_mouse_move
+ self._mouse.on_scroll += self._handle_mouse_scroll
+
+ self.clear_bindings()
+
+ def __del__(self):
+ """Instance deleter."""
+ self._keyboard.on_key -= self._handle_key
+ self._mouse.on_click -= self._handle_key
+ self._mouse.on_double_click -= self._handle_double_click
+ self._mouse.on_move -= self._handle_mouse_move
+ self._mouse.on_scroll -= self._handle_mouse_scroll
+
+ def clear_bindings(self):
+ """Clears registered action bindings, while keeping key aliases."""
+ self._action_callbacks = {}
+ self._double_click_callbacks = {}
+ self._plane_callback = []
+ self._z_axis_callback = []
+ self._active_exclusive = _NO_EXCLUSIVE_KEY
+
+ def bind(self, callback, key_binding):
+ """Binds a key combination to a callback.
+
+ Args:
+ callback: An argument-less callable.
+      key_binding: An integer key code, a tuple (keycode, modifier) or one of
+        the actions Exclusive|DoubleClick|Range carrying the key combination.
+ """
+ def build_callback(index, callback):
+ def indexed_callback():
+ callback(index)
+ return indexed_callback
+
+ if isinstance(key_binding, Range):
+ for index, binding in enumerate(key_binding.collection):
+ self._add_binding(build_callback(index, callback), binding)
+ else:
+ self._add_binding(callback, key_binding)
+
+ def _add_binding(self, callback, key_binding):
+ key_combination = self._extract_key_combination(key_binding)
+ if isinstance(key_binding, Exclusive):
+ self._action_callbacks[key_combination] = (True, callback)
+ elif isinstance(key_binding, DoubleClick):
+ self._double_click_callbacks[key_combination] = callback
+ else:
+ self._action_callbacks[key_combination] = (False, callback)
+
+ def _extract_key_combination(self, key_binding):
+ if isinstance(key_binding, Exclusive):
+ key_binding = key_binding.combination
+ elif isinstance(key_binding, DoubleClick):
+ key_binding = key_binding.combination
+
+ if not isinstance(key_binding, tuple):
+ key_binding = (key_binding, MOD_NONE)
+ return key_binding
+
+ def bind_plane(self, callback):
+ """Binds a callback to a planar motion action (mouse movement)."""
+ self._plane_callback.append(callback)
+
+ def bind_z_axis(self, callback):
+ """Binds a callback to a z-axis motion action (mouse scroll)."""
+ self._z_axis_callback.append(callback)
+
+ def _handle_key(self, key, action, modifiers):
+ """Handles a single key press (mouse and keyboard)."""
+ alias_key = (key, modifiers)
+
+ exclusive_key, exclusive_callback = self._active_exclusive
+ if exclusive_key is not None:
+ if action == RELEASE and key == exclusive_key:
+ exclusive_callback(False)
+ self._active_exclusive = _NO_EXCLUSIVE_KEY
+ else:
+ is_exclusive, callback = self._action_callbacks.get(
+ alias_key, _NO_CALLBACK)
+ if callback:
+ if action == PRESS:
+ if is_exclusive:
+ callback(True)
+ self._active_exclusive = (key, callback)
+ else:
+ callback()
+
+ def _handle_double_click(self, key, modifiers):
+ """Handles a double mouse click."""
+ alias_key = (key, modifiers)
+ callback = self._double_click_callbacks.get(alias_key, None)
+ if callback is not None:
+ callback()
+
+ def _handle_mouse_move(self, position, translation):
+ """Handles mouse move."""
+ for callback in self._plane_callback:
+ callback(position, translation)
+
+ def _handle_mouse_scroll(self, value):
+ """Handles mouse wheel scroll."""
+ for callback in self._z_axis_callback:
+ callback(value)
diff --git a/dm_control/viewer/user_input_test.py b/dm_control/viewer/user_input_test.py
new file mode 100644
index 00000000..ff1927b8
--- /dev/null
+++ b/dm_control/viewer/user_input_test.py
@@ -0,0 +1,172 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the user_input module."""
+
+
+from absl.testing import absltest
+from dm_control.viewer import user_input
+import mock
+
+
+class InputMapTests(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.mouse = mock.MagicMock()
+ self.keyboard = mock.MagicMock()
+ self.input_map = user_input.InputMap(self.mouse, self.keyboard)
+
+ self.callback = mock.MagicMock()
+
+ def test_clearing_bindings(self):
+ self.input_map._active_exclusive = 1
+ self.input_map._action_callbacks = {1: 2}
+ self.input_map._double_click_callbacks = {3: 4}
+ self.input_map._plane_callback = [5]
+ self.input_map._z_axis_callback = [6]
+
+ self.input_map.clear_bindings()
+
+ self.assertEmpty(self.input_map._action_callbacks)
+ self.assertEmpty(self.input_map._double_click_callbacks)
+ self.assertEmpty(self.input_map._plane_callback)
+ self.assertEmpty(self.input_map._z_axis_callback)
+ self.assertEqual(
+ user_input._NO_EXCLUSIVE_KEY, self.input_map._active_exclusive)
+
+ def test_binding(self):
+ self.input_map.bind(self.callback, user_input.KEY_UP)
+ expected_dict = {
+ (user_input.KEY_UP, user_input.MOD_NONE): (False, self.callback)}
+ self.assertDictEqual(expected_dict, self.input_map._action_callbacks)
+
+ def test_binding_exclusive(self):
+ self.input_map.bind(self.callback, user_input.Exclusive(user_input.KEY_UP))
+ expected_dict = {
+ (user_input.KEY_UP, user_input.MOD_NONE): (True, self.callback)}
+ self.assertDictEqual(expected_dict, self.input_map._action_callbacks)
+
+ def test_binding_and_invoking_ranges_of_actions(self):
+ self.input_map.bind(self.callback, user_input.Range(
+ [user_input.KEY_UP, (user_input.KEY_UP, user_input.MOD_ALT)]))
+
+ self.input_map._handle_key(
+ user_input.KEY_UP, user_input.PRESS, user_input.MOD_NONE)
+ self.callback.assert_called_once_with(0)
+
+ self.callback.reset_mock()
+ self.input_map._handle_key(
+ user_input.KEY_UP, user_input.PRESS, user_input.MOD_ALT)
+ self.callback.assert_called_once_with(1)
+
+ def test_binding_planar_action(self):
+ self.input_map.bind_plane(self.callback)
+ self.assertLen(self.input_map._plane_callback, 1)
+ self.assertEqual(self.callback, self.input_map._plane_callback[0])
+
+ def test_binding_z_axis_action(self):
+ self.input_map.bind_z_axis(self.callback)
+ self.assertLen(self.input_map._z_axis_callback, 1)
+ self.assertEqual(self.callback, self.input_map._z_axis_callback[0])
+
+ def test_invoking_regular_action_in_response_to_click(self):
+ self.input_map._action_callbacks = {(1, 2): (False, self.callback)}
+
+ self.input_map._handle_key(1, user_input.PRESS, 2)
+ self.callback.assert_called_once()
+ self.callback.reset_mock()
+
+ self.input_map._handle_key(1, user_input.RELEASE, 2)
+ self.assertEqual(0, self.callback.call_count)
+
+ def test_invoking_exclusive_action_in_response_to_click(self):
+ self.input_map._action_callbacks = {(1, 2): (True, self.callback)}
+
+ self.input_map._handle_key(1, user_input.PRESS, 2)
+ self.callback.assert_called_once_with(True)
+ self.callback.reset_mock()
+
+ self.input_map._handle_key(1, user_input.RELEASE, 2)
+ self.callback.assert_called_once_with(False)
+
+ def test_exclusive_action_blocks_other_actions_until_its_finished(self):
+ self.input_map._action_callbacks = {
+ (1, 2): (True, self.callback), (3, 4): (False, self.callback)}
+
+ self.input_map._handle_key(1, user_input.PRESS, 2)
+ self.callback.assert_called_once_with(True)
+ self.callback.reset_mock()
+
+ # Attempting to start other actions (PRESS) or end them (RELEASE)
+ # amounts to nothing.
+ self.input_map._handle_key(3, user_input.PRESS, 4)
+ self.assertEqual(0, self.callback.call_count)
+
+ self.input_map._handle_key(3, user_input.RELEASE, 4)
+ self.assertEqual(0, self.callback.call_count)
+
+ # Even attempting to start the same action for the 2nd time fails.
+ self.input_map._handle_key(1, user_input.PRESS, 2)
+ self.assertEqual(0, self.callback.call_count)
+
+ # Only finishing the action frees up the resources.
+ self.input_map._handle_key(1, user_input.RELEASE, 2)
+ self.callback.assert_called_once_with(False)
+ self.callback.reset_mock()
+
+ # Now we can start a new action.
+ self.input_map._handle_key(3, user_input.PRESS, 4)
+ self.callback.assert_called_once()
+
+ def test_modifiers_required_only_for_exclusive_action_start(self):
+ activation_modifiers = 2
+ no_modifiers = 0
+ self.input_map._action_callbacks = {
+ (1, activation_modifiers): (True, self.callback)}
+
+ self.input_map._handle_key(1, user_input.PRESS, activation_modifiers)
+ self.callback.assert_called_once_with(True)
+ self.callback.reset_mock()
+
+ self.input_map._handle_key(1, user_input.RELEASE, no_modifiers)
+ self.callback.assert_called_once_with(False)
+
+ def test_invoking_regular_action_in_response_to_double_click(self):
+ self.input_map._double_click_callbacks = {(1, 2): self.callback}
+ self.input_map._handle_double_click(1, 2)
+ self.callback.assert_called_once()
+
+ def test_exclusive_actions_dont_respond_to_double_clicks(self):
+ self.input_map._action_callbacks = {(1, 2): (True, self.callback)}
+
+ self.input_map._handle_double_click(1, 2)
+ self.assertEqual(0, self.callback.call_count)
+
+ def test_mouse_move(self):
+ position = [1, 2]
+ translation = [3, 4]
+ self.input_map._plane_callback = [self.callback]
+ self.input_map._handle_mouse_move(position, translation)
+ self.callback.assert_called_once_with(position, translation)
+
+ def test_mouse_scroll(self):
+ value = 5
+ self.input_map._z_axis_callback = [self.callback]
+ self.input_map._handle_mouse_scroll(value)
+ self.callback.assert_called_once_with(value)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/util.py b/dm_control/viewer/util.py
new file mode 100644
index 00000000..82ead439
--- /dev/null
+++ b/dm_control/viewer/util.py
@@ -0,0 +1,335 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utility classes."""
+
+import collections
+import contextlib
+import itertools
+import sys
+import time
+import traceback
+
+from absl import logging
+
+# Lower bound of the time multiplier set through TimeMultiplier class.
+_MIN_TIME_MULTIPLIER = 1./32.
+# Upper bound of the time multiplier set through TimeMultiplier class.
+_MAX_TIME_MULTIPLIER = 1.
+
+
+def is_scalar(value):
+ """Checks if the supplied value can be converted to a scalar."""
+ try:
+ float(value)
+ except (TypeError, ValueError):
+ return False
+ else:
+ return True
+
+
+def to_iterable(item):
+ """Converts an item or iterable into an iterable."""
+ if isinstance(item, str):
+ return [item]
+ elif isinstance(item, collections.abc.Iterable):
+ return item
+ else:
+ return [item]
+
+
+class QuietSet:
+ """A set-like container that quietly processes removals of missing keys."""
+
+ def __init__(self):
+ self._items = set()
+
+ def __iadd__(self, items):
+ """Adds `items`, avoiding duplicates.
+
+ Args:
+ items: An iterable of items to add, or a single item to add.
+
+ Returns:
+ This instance of `QuietSet`.
+ """
+ self._items.update(to_iterable(items))
+ self._items.discard(None)
+ return self
+
+ def __isub__(self, items):
+ """Detaches `items`.
+
+ Args:
+ items: An iterable of items to detach, or a single item to detach.
+
+ Returns:
+ This instance of `QuietSet`.
+ """
+ for item in to_iterable(items):
+ self._items.discard(item)
+ return self
+
+ def __len__(self):
+ return len(self._items)
+
+ def __iter__(self):
+ return iter(self._items)
+
+
+def interleave(a, b):
+ """Interleaves the contents of two iterables."""
+ return itertools.chain.from_iterable(zip(a, b))
+
+
+class TimeMultiplier:
+ """Controls the relative speed of the simulation compared to realtime."""
+
+ def __init__(self, initial_time_multiplier):
+ """Instance initializer.
+
+ Args:
+ initial_time_multiplier: A float scalar specifying the initial speed of
+ the simulation with 1.0 corresponding to realtime.
+ """
+ self.set(initial_time_multiplier)
+
+ def get(self):
+ """Returns the current time factor value."""
+ return self._real_time_multiplier
+
+ def set(self, value):
+ """Modifies the time factor.
+
+ Args:
+ value: A float scalar, new value of the time factor.
+ """
+ self._real_time_multiplier = max(
+ _MIN_TIME_MULTIPLIER, min(_MAX_TIME_MULTIPLIER, value))
+
+ def __str__(self):
+ """Returns a formatted string containing the time factor."""
+ if self._real_time_multiplier >= 1.0:
+ time_factor = '%d' % self._real_time_multiplier
+ else:
+ time_factor = '1/%d' % (1.0 // self._real_time_multiplier)
+ return time_factor
+
+ def increase(self):
+ """Doubles the current time factor value."""
+ self.set(self._real_time_multiplier * 2.)
+
+ def decrease(self):
+ """Halves the current time factor value."""
+ self.set(self._real_time_multiplier / 2.)
+
+
+class Integrator:
+ """Integrates a value and averages it for the specified period of time."""
+
+ def __init__(self, refresh_rate=.5):
+ """Instance initializer.
+
+ Args:
+      refresh_rate: How often, in seconds, the integrated value is averaged.
+ """
+ self._value = 0
+ self._value_acc = 0
+ self._num_samples = 0
+ self._sampling_timestamp = time.time()
+ self._refresh_rate = refresh_rate
+
+ @property
+ def value(self):
+ """Returns the averaged value."""
+ return self._value
+
+ @value.setter
+ def value(self, val):
+ """Integrates the new value."""
+ self._value_acc += val
+ self._num_samples += 1
+
+ time_elapsed = time.time() - self._sampling_timestamp
+ if time_elapsed >= self._refresh_rate:
+ self._value = self._value_acc / self._num_samples
+ self._value_acc = 0
+ self._num_samples = 0
+ self._sampling_timestamp = time.time()
+
+
+class AtomicAction:
+ """An action that cannot be interrupted."""
+
+ def __init__(self, state_change_callback=None):
+ """Instance initializer.
+
+ Args:
+ state_change_callback: Callable invoked when action changes its state.
+ """
+ self._state_change_callback = state_change_callback
+ self._watermark = None
+
+ def begin(self, watermark):
+ """Begins the action, signing it with the specified watermark."""
+ if self._watermark is None:
+ self._watermark = watermark
+ if self._state_change_callback is not None:
+ self._state_change_callback(watermark)
+
+ def end(self, watermark):
+ """Ends a started action, provided the watermarks match."""
+ if self._watermark == watermark:
+ self._watermark = None
+ if self._state_change_callback is not None:
+ self._state_change_callback(None)
+
+ @property
+  def in_progress(self):
+    """Returns True if the action has begun and has not yet ended."""
+ return self._watermark is not None
+
+ @property
+ def watermark(self):
+ """Returns the watermark passed to begin() method call, or None.
+
+ None will be returned if the action is not in progress.
+ """
+ return self._watermark
+
+
+class ObservableFlag(QuietSet):
+ """Observable boolean flag.
+
+  The QuietSet base class provides the functionality for managing listeners.
+
+ A listener is a callable that takes one boolean parameter.
+ """
+
+ def __init__(self, initial_value):
+ """Instance initializer.
+
+ Args:
+ initial_value: A boolean value with the initial state of the flag.
+ """
+ self._value = initial_value
+ super().__init__()
+
+ def toggle(self):
+ """Toggles the value True/False."""
+ self._value = not self._value
+ for listener in self._items:
+ listener(self._value)
+
+ def __iadd__(self, value):
+ """Add new listeners and update them about the state."""
+ listeners = to_iterable(value)
+ super().__iadd__(listeners)
+ for listener in listeners:
+ listener(self._value)
+ return self
+
+ @property
+ def value(self):
+ """Value of the flag."""
+ return self._value
+
+ @value.setter
+ def value(self, val):
+ if self._value != val:
+ self._value = val
+ for listener in self._items:
+ listener(self._value)
+
+
+class Timer:
+ """Measures time elapsed between two ticks."""
+
+ def __init__(self):
+ """Instance initializer."""
+ self._previous_time = time.time()
+ self._measured_time = 0.
+
+ def tick(self):
+ """Updates the timer.
+
+ Returns:
+ Time elapsed since the last call to this method.
+ """
+ curr_time = time.time()
+ self._measured_time = curr_time - self._previous_time
+ self._previous_time = curr_time
+ return self._measured_time
+
+ @contextlib.contextmanager
+ def measure_time(self):
+ start_time = time.time()
+ yield
+ self._measured_time = time.time() - start_time
+
+ @property
+ def measured_time(self):
+ return self._measured_time
+
+
+class ErrorLogger:
+ """A context manager that catches and logs all errors."""
+
+ def __init__(self, listeners):
+ """Instance initializer.
+
+ Args:
+ listeners: An iterable of callables, listeners to inform when an error
+ is caught. Each callable should accept a single string argument.
+ """
+ self._error_found = False
+ self._listeners = listeners
+
+ def __enter__(self, *args):
+ self._error_found = False
+
+ def __exit__(self, exception_type, exception_value, tb):
+ if exception_type:
+ self._error_found = True
+ error_message = ('dm_control viewer intercepted an environment error.\n'
+ 'Original message: {}'.format(exception_value))
+ logging.error(error_message)
+ sys.stderr.write(error_message + '\nTraceback:\n')
+ traceback.print_tb(tb)
+ for listener in self._listeners:
+ listener('{}'.format(exception_value))
+ return True
+
+ @property
+ def errors_found(self):
+ """Returns True if any errors were caught."""
+ return self._error_found
+
+
+class NullErrorLogger:
+ """A context manager that replaces an ErrorLogger.
+
+ This error logger will pass all thrown errors through.
+ """
+
+ def __enter__(self, *args):
+ pass
+
+ def __exit__(self, error_type, value, tb):
+ pass
+
+ @property
+ def errors_found(self):
+ """Returns True if any errors were caught."""
+ return False
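
Review note on `AtomicAction`: the watermark rules can be checked in isolation. The sketch below is a stand-alone copy of the class as it appears in this diff, with the nesting assumed from the behaviour exercised in `util_test.py`. It is not an import of `dm_control`. The `'pan'` and `'rotate'` watermarks are illustrative placeholders for the `mujoco.mjtMouse` values the viewer passes.

```python
class AtomicAction:
  """Stand-alone copy of the class from the diff, for illustration only."""

  def __init__(self, state_change_callback=None):
    self._state_change_callback = state_change_callback
    self._watermark = None

  def begin(self, watermark):
    # A new action starts only if no other action is in progress.
    if self._watermark is None:
      self._watermark = watermark
      if self._state_change_callback is not None:
        self._state_change_callback(watermark)

  def end(self, watermark):
    # Only the watermark that started the action can end it.
    if self._watermark == watermark:
      self._watermark = None
      if self._state_change_callback is not None:
        self._state_change_callback(None)

  @property
  def in_progress(self):
    return self._watermark is not None

  @property
  def watermark(self):
    return self._watermark


events = []
action = AtomicAction(events.append)
action.begin('pan')     # Starts the action and notifies the callback.
action.begin('rotate')  # Ignored: 'pan' is still in progress.
action.end('rotate')    # Ignored: the watermark does not match.
still_panning = action.in_progress
action.end('pan')       # Ends the action and notifies the callback with None.
```

The callback therefore fires only on real state changes, which is what `ManipulationController._update_action` relies on to pair each `start_move` with one `end_move`.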
diff --git a/dm_control/viewer/util_test.py b/dm_control/viewer/util_test.py
new file mode 100644
index 00000000..49de0c25
--- /dev/null
+++ b/dm_control/viewer/util_test.py
@@ -0,0 +1,319 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the util module."""
+
+import collections
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.viewer import util
+import mock
+import numpy as np
+
+
+class QuietSetTest(absltest.TestCase):
+
+ def test_add_listeners(self):
+ subject = util.QuietSet()
+ listeners = [object() for _ in range(5)]
+ for listener in listeners:
+ subject += listener
+ self.assertLen(subject, 5)
+
+ def test_add_collection_of_listeners(self):
+ subject = util.QuietSet()
+ subject += [object() for _ in range(5)]
+ self.assertLen(subject, 5)
+
+ def test_add_collection_and_individual_listeners(self):
+ subject = util.QuietSet()
+ subject += object()
+ subject += [object() for _ in range(5)]
+ subject += object()
+ self.assertLen(subject, 7)
+
+ def test_add_duplicate_listeners(self):
+ subject = util.QuietSet()
+ listener = object()
+ subject += listener
+ self.assertLen(subject, 1)
+ subject += listener
+ self.assertLen(subject, 1)
+
+ def test_remove_listeners(self):
+ subject = util.QuietSet()
+ listeners = [object() for _ in range(3)]
+ for listener in listeners:
+ subject += listener
+
+ subject -= listeners[1]
+ self.assertLen(subject, 2)
+
+ def test_remove_unregistered_listener(self):
+ subject = util.QuietSet()
+ listeners = [object() for _ in range(3)]
+ for listener in listeners:
+ subject += listener
+
+ subject -= object()
+ self.assertLen(subject, 3)
+
+
+class ToIterableTest(parameterized.TestCase):
+
+ def test_scalars_converted_to_iterables(self):
+ original_value = 3
+
+ result = util.to_iterable(original_value)
+ self.assertIsInstance(result, collections.abc.Iterable)
+ self.assertLen(result, 1)
+ self.assertEqual(original_value, result[0])
+
+ def test_strings_wrapped_by_list(self):
+ original_value = 'test_string'
+
+ result = util.to_iterable(original_value)
+ self.assertIsInstance(result, collections.abc.Iterable)
+ self.assertLen(result, 1)
+ self.assertEqual(original_value, result[0])
+
+ @parameterized.named_parameters(
+ ('list', [1, 2, 3]),
+ ('set', set([1, 2, 3])),
+ ('dict', {'1': 2, '3': 4, '5': 6})
+ )
+ def test_iterables_remain_unaffected(self, original_value):
+ result = util.to_iterable(original_value)
+ self.assertEqual(result, original_value)
+
+
+class InterleaveTest(absltest.TestCase):
+
+ def test_equal_sized_iterables(self):
+ a = [1, 2, 3]
+ b = [4, 5, 6]
+ c = [i for i in util.interleave(a, b)]
+ np.testing.assert_array_equal([1, 4, 2, 5, 3, 6], c)
+
+ def test_iteration_ends_when_smaller_iterable_runs_out_of_elements(self):
+ a = [1, 2, 3]
+ b = [4, 5, 6, 7, 8]
+ c = [i for i in util.interleave(a, b)]
+ np.testing.assert_array_equal([1, 4, 2, 5, 3, 6], c)
+
+
+class TimeMultiplierTests(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.factor = util.TimeMultiplier(initial_time_multiplier=1.0)
+
+ def test_custom_initial_factor(self):
+ initial_value = 0.5
+ factor = util.TimeMultiplier(initial_time_multiplier=initial_value)
+ self.assertEqual(initial_value, factor.get())
+
+ def test_initial_factor_clamped_to_valid_value_range(self):
+ too_large_multiplier = util._MAX_TIME_MULTIPLIER + 1.
+ too_small_multiplier = util._MIN_TIME_MULTIPLIER - 1.
+
+ factor = util.TimeMultiplier(initial_time_multiplier=too_large_multiplier)
+ self.assertEqual(util._MAX_TIME_MULTIPLIER, factor.get())
+
+ factor = util.TimeMultiplier(initial_time_multiplier=too_small_multiplier)
+ self.assertEqual(util._MIN_TIME_MULTIPLIER, factor.get())
+
+ def test_increase(self):
+ self.factor.decrease()
+ self.factor.decrease()
+ self.factor.increase()
+ self.assertEqual(self.factor._real_time_multiplier, 0.5)
+
+ def test_increase_limit(self):
+ self.factor._real_time_multiplier = util._MAX_TIME_MULTIPLIER
+ self.factor.increase()
+ self.assertEqual(util._MAX_TIME_MULTIPLIER, self.factor.get())
+
+ def test_decrease(self):
+ self.factor.decrease()
+ self.factor.decrease()
+ self.assertEqual(self.factor._real_time_multiplier, 0.25)
+
+ def test_decrease_limit(self):
+ self.factor._real_time_multiplier = util._MIN_TIME_MULTIPLIER
+ self.factor.decrease()
+ self.assertEqual(util._MIN_TIME_MULTIPLIER, self.factor.get())
+
+ def test_stringify_when_less_than_one(self):
+ self.assertEqual('1', str(self.factor))
+ self.factor.decrease()
+ self.assertEqual('1/2', str(self.factor))
+ self.factor.decrease()
+ self.assertEqual('1/4', str(self.factor))
+
+
+class IntegratorTests(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.integration_step = 1
+ self.integrator = util.Integrator(self.integration_step)
+ self.integrator._sampling_timestamp = 0.0
+
+ def test_initial_value(self):
+ self.assertEqual(0, self.integrator.value)
+
+ def test_integration_step(self):
+ with mock.patch(util.__name__ + '.time') as time_mock:
+ time_mock.time.return_value = self.integration_step
+ self.integrator.value = 1
+ self.assertEqual(1, self.integrator.value)
+
+ def test_averaging(self):
+ with mock.patch(util.__name__ + '.time') as time_mock:
+ time_mock.time.return_value = 0
+ self.integrator.value = 1
+ self.integrator.value = 1
+ self.integrator.value = 1
+ time_mock.time.return_value = self.integration_step
+ self.integrator.value = 1
+ self.assertEqual(1, self.integrator.value)
+
+
+class AtomicActionTests(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.callback = mock.MagicMock()
+ self.action = util.AtomicAction(self.callback)
+
+ def test_starting_and_ending_one_action(self):
+ self.action.begin(1)
+ self.assertEqual(1, self.action.watermark)
+ self.callback.assert_called_once_with(1)
+
+ self.callback.reset_mock()
+
+ self.action.end(1)
+ self.assertIsNone(self.action.watermark)
+ self.callback.assert_called_once_with(None)
+
+ def test_trying_to_interrupt_with_another_action(self):
+ self.action.begin(1)
+ self.assertEqual(1, self.action.watermark)
+ self.callback.assert_called_once_with(1)
+
+ self.callback.reset_mock()
+
+ self.action.begin(2)
+ self.assertEqual(1, self.action.watermark)
+ self.assertEqual(0, self.callback.call_count)
+
+ def test_trying_to_end_another_action(self):
+ self.action.begin(1)
+ self.callback.reset_mock()
+
+ self.action.end(2)
+ self.assertEqual(1, self.action.watermark)
+ self.assertEqual(0, self.callback.call_count)
+
+
+class ObservableFlagTest(absltest.TestCase):
+
+ def test_update_each_added_listener(self):
+ listener = mock.MagicMock(spec=object)
+
+ subject = util.ObservableFlag(True)
+
+ subject += listener
+ listener.assert_called_once_with(True)
+
+ def test_update_listeners_on_toggle(self):
+ listeners = [mock.MagicMock(spec=object) for _ in range(10)]
+
+ subject = util.ObservableFlag(True)
+ subject += listeners
+
+ for listener in listeners:
+ listener.reset_mock()
+ subject.toggle()
+ for listener in listeners:
+ listener.assert_called_once_with(False)
+
+
+class TimerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.timer = util.Timer()
+
+ def test_time_elapsed(self):
+ with mock.patch(util.__name__ + '.time') as time_mock:
+ time_mock.time.return_value = 1
+ self.timer.tick()
+ time_mock.time.return_value = 2
+ self.assertEqual(1, self.timer.tick())
+
+ def test_time_measurement(self):
+ with mock.patch(util.__name__ + '.time') as time_mock:
+ time_mock.time.return_value = 1
+ with self.timer.measure_time():
+ time_mock.time.return_value = 4
+ self.assertEqual(3, self.timer.measured_time)
+
+
+class ErrorLoggerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.callback = mock.MagicMock()
+ self.logger = util.ErrorLogger([self.callback])
+
+ def test_no_errors_found_on_initialization(self):
+ self.assertFalse(self.logger.errors_found)
+
+ def test_no_error_caught(self):
+ with self.logger:
+ pass
+ self.assertFalse(self.logger.errors_found)
+
+ def test_error_caught(self):
+ with self.logger:
+ raise Exception('error message')
+ self.assertTrue(self.logger.errors_found)
+
+ def test_notifying_callbacks(self):
+ error_message = 'error message'
+ with self.logger:
+ raise Exception(error_message)
+ self.callback.assert_called_once_with(error_message)
+
+
+class NullErrorLoggerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.logger = util.NullErrorLogger()
+
+ def test_thrown_errors_are_not_being_intercepted(self):
+ with self.assertRaises(Exception):
+ with self.logger:
+ raise Exception()
+
+ def test_errors_found_always_returns_false(self):
+ self.assertFalse(self.logger.errors_found)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/viewer.py b/dm_control/viewer/viewer.py
new file mode 100644
index 00000000..583519ab
--- /dev/null
+++ b/dm_control/viewer/viewer.py
@@ -0,0 +1,546 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MuJoCo physics viewer, with custom input controllers."""
+
+
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.viewer import renderer
+from dm_control.viewer import user_input
+from dm_control.viewer import util
+import mujoco
+
+functions = mjbindings.functions
+
+_NUM_GROUP_KEYS = 10
+
+_PAN_CAMERA_VERTICAL_MOUSE = user_input.Exclusive(
+ user_input.MOUSE_BUTTON_RIGHT)
+_PAN_CAMERA_HORIZONTAL_MOUSE = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_RIGHT, user_input.MOD_SHIFT))
+_ROTATE_OBJECT_MOUSE = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_LEFT, user_input.MOD_CONTROL))
+_MOVE_OBJECT_VERTICAL_MOUSE = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_RIGHT, user_input.MOD_CONTROL))
+_MOVE_OBJECT_HORIZONTAL_MOUSE = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_RIGHT, user_input.MOD_SHIFT_CONTROL))
+
+_PAN_CAMERA_VERTICAL_TOUCHPAD = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_LEFT, user_input.MOD_ALT))
+_PAN_CAMERA_HORIZONTAL_TOUCHPAD = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_RIGHT, user_input.MOD_ALT))
+_ROTATE_OBJECT_TOUCHPAD = user_input.Exclusive(
+ user_input.MOUSE_BUTTON_RIGHT)
+_MOVE_OBJECT_VERTICAL_TOUCHPAD = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_LEFT, user_input.MOD_CONTROL))
+_MOVE_OBJECT_HORIZONTAL_TOUCHPAD = user_input.Exclusive(
+ (user_input.MOUSE_BUTTON_LEFT, user_input.MOD_SHIFT_CONTROL))
+
+_ROTATE_CAMERA = user_input.Exclusive(user_input.MOUSE_BUTTON_LEFT)
+_CENTER_CAMERA = user_input.DoubleClick(user_input.MOUSE_BUTTON_RIGHT)
+_SELECT_OBJECT = user_input.DoubleClick(user_input.MOUSE_BUTTON_LEFT)
+_TRACK_OBJECT = user_input.DoubleClick(
+ (user_input.MOUSE_BUTTON_RIGHT, user_input.MOD_CONTROL))
+_FREE_LOOK = user_input.KEY_ESCAPE
+_NEXT_CAMERA = user_input.KEY_RIGHT_BRACKET
+_PREVIOUS_CAMERA = user_input.KEY_LEFT_BRACKET
+_ZOOM_TO_SCENE = (user_input.KEY_A, user_input.MOD_CONTROL)
+_DOUBLE_BUFFERING = user_input.KEY_F5
+_PREV_RENDERING_MODE = (user_input.KEY_F6, user_input.MOD_SHIFT)
+_NEXT_RENDERING_MODE = user_input.KEY_F6
+_PREV_LABELING_MODE = (user_input.KEY_F7, user_input.MOD_SHIFT)
+_NEXT_LABELING_MODE = user_input.KEY_F7
+_PRINT_CAMERA = user_input.KEY_F11
+_VISUALIZATION_FLAGS = user_input.Range([
+ ord(functions.mjVISSTRING[i][2]) if functions.mjVISSTRING[i][2] else 0
+ for i in range(0, mujoco.mjtVisFlag.mjNVISFLAG)
+])
+_GEOM_GROUPS = user_input.Range(
+ [i + ord('0') for i in range(min(_NUM_GROUP_KEYS, mujoco.mjNGROUP))])
+_SITE_GROUPS = user_input.Range([
+ (i + ord('0'), user_input.MOD_SHIFT)
+ for i in range(min(_NUM_GROUP_KEYS, mujoco.mjNGROUP))
+])
+_RENDERING_FLAGS = user_input.Range([
+ ord(functions.mjRNDSTRING[i][2]) if functions.mjRNDSTRING[i][2] else 0
+ for i in range(0, mujoco.mjtRndFlag.mjNRNDFLAG)
+])
+
+_CAMERA_MOVEMENT_ACTIONS = [
+ mujoco.mjtMouse.mjMOUSE_MOVE_V, mujoco.mjtMouse.mjMOUSE_ROTATE_H
+]
+
+# Translates mouse wheel rotations to zoom speed.
+_SCROLL_SPEED_FACTOR = 0.05
+
+# Distance, in meters, at which to focus on the clicked object.
+_LOOK_AT_DISTANCE = 1.5
+
+# Zoom factor used when zooming in on the entire scene.
+_FULL_SCENE_ZOOM_FACTOR = 1.5
+
+
+class Viewer:
+ """Viewport displaying the contents of a physics world."""
+
+ def __init__(self, viewport, mouse, keyboard, camera_settings=None,
+ zoom_factor=_FULL_SCENE_ZOOM_FACTOR, scene_callback=None):
+ """Instance initializer.
+
+ Args:
+ viewport: Render viewport, instance of renderer.Viewport.
+ mouse: A mouse device.
+ keyboard: A keyboard device.
+ camera_settings: Properties of the scene MjvCamera.
+ zoom_factor: Initial scale factor for zooming into the scene.
+ scene_callback: Scene callback.
+ This is a callable of the form: `my_callable(MjModel, MjData, MjvScene)`
+ that gets applied to every rendered scene.
+ """
+ self._viewport = viewport
+ self._mouse = mouse
+
+ self._null_perturbation = renderer.NullPerturbation()
+ self._render_settings = renderer.RenderSettings()
+ self._input_map = user_input.InputMap(mouse, keyboard)
+
+ self._camera = None
+ self._camera_settings = camera_settings
+ self._renderer = None
+ self._manipulator = None
+ self._free_camera = None
+ self._camera_select = None
+ self._zoom_factor = zoom_factor
+ self._scene_callback = scene_callback
+
+ def __del__(self):
+ del self._camera
+ del self._renderer
+ del self._manipulator
+ del self._free_camera
+ del self._camera_select
+
+ def initialize(self, physics, renderer_instance, touchpad):
+ """Initialize the viewer.
+
+ Args:
+ physics: Physics instance.
+ renderer_instance: A renderer.Base instance.
+ touchpad: A boolean, use input dedicated to touchpad.
+ """
+ self._camera = renderer.SceneCamera(
+ physics.model,
+ physics.data,
+ self._render_settings,
+ settings=self._camera_settings,
+ zoom_factor=self._zoom_factor,
+ scene_callback=self._scene_callback)
+
+ self._manipulator = ManipulationController(
+ self._viewport, self._camera, self._mouse)
+
+ self._free_camera = FreeCameraController(
+ self._viewport, self._camera, self._mouse, self._manipulator)
+
+ self._camera_select = CameraSelector(
+ physics.model, self._camera, self._free_camera)
+
+ self._renderer = renderer_instance
+
+ self._input_map.clear_bindings()
+
+ if touchpad:
+ self._input_map.bind(
+ self._manipulator.set_move_vertical_mode,
+ _MOVE_OBJECT_VERTICAL_TOUCHPAD)
+ self._input_map.bind(
+ self._manipulator.set_move_horizontal_mode,
+ _MOVE_OBJECT_HORIZONTAL_TOUCHPAD)
+ self._input_map.bind(
+ self._manipulator.set_rotate_mode, _ROTATE_OBJECT_TOUCHPAD)
+ self._input_map.bind(
+ self._free_camera.set_pan_vertical_mode,
+ _PAN_CAMERA_VERTICAL_TOUCHPAD)
+ self._input_map.bind(
+ self._free_camera.set_pan_horizontal_mode,
+ _PAN_CAMERA_HORIZONTAL_TOUCHPAD)
+ else:
+ self._input_map.bind(
+ self._manipulator.set_move_vertical_mode, _MOVE_OBJECT_VERTICAL_MOUSE)
+ self._input_map.bind(
+ self._manipulator.set_move_horizontal_mode,
+ _MOVE_OBJECT_HORIZONTAL_MOUSE)
+ self._input_map.bind(
+ self._manipulator.set_rotate_mode, _ROTATE_OBJECT_MOUSE)
+ self._input_map.bind(
+ self._free_camera.set_pan_vertical_mode, _PAN_CAMERA_VERTICAL_MOUSE)
+ self._input_map.bind(
+ self._free_camera.set_pan_horizontal_mode,
+ _PAN_CAMERA_HORIZONTAL_MOUSE)
+
+ self._input_map.bind(self._print_camera_transform, _PRINT_CAMERA)
+ self._input_map.bind(
+ self._render_settings.select_prev_rendering_mode, _PREV_RENDERING_MODE)
+ self._input_map.bind(
+ self._render_settings.select_next_rendering_mode, _NEXT_RENDERING_MODE)
+ self._input_map.bind(
+ self._render_settings.select_prev_labeling_mode, _PREV_LABELING_MODE)
+ self._input_map.bind(
+ self._render_settings.select_next_labeling_mode, _NEXT_LABELING_MODE)
+ self._input_map.bind(
+ self._render_settings.select_prev_labeling_mode, _PREV_LABELING_MODE)
+ self._input_map.bind(
+ self._render_settings.toggle_stereo_buffering, _DOUBLE_BUFFERING)
+ self._input_map.bind(
+ self._render_settings.toggle_visualization_flag, _VISUALIZATION_FLAGS)
+ self._input_map.bind(
+ self._render_settings.toggle_site_group, _SITE_GROUPS)
+ self._input_map.bind(
+ self._render_settings.toggle_geom_group, _GEOM_GROUPS)
+ self._input_map.bind(
+ self._render_settings.toggle_rendering_flag, _RENDERING_FLAGS)
+
+ self._input_map.bind(self._camera.zoom_to_scene, _ZOOM_TO_SCENE)
+ self._input_map.bind(self._camera_select.select_next, _NEXT_CAMERA)
+ self._input_map.bind(self._camera_select.select_previous, _PREVIOUS_CAMERA)
+ self._input_map.bind_z_axis(self._free_camera.zoom)
+ self._input_map.bind_plane(self._free_camera.on_move)
+ self._input_map.bind(self._free_camera.set_rotate_mode, _ROTATE_CAMERA)
+ self._input_map.bind(self._free_camera.center, _CENTER_CAMERA)
+ self._input_map.bind(self._free_camera.track, _TRACK_OBJECT)
+ self._input_map.bind(self._camera_select.escape, _FREE_LOOK)
+ self._input_map.bind(self._manipulator.select, _SELECT_OBJECT)
+ self._input_map.bind_plane(self._manipulator.on_move)
+
+ def deinitialize(self):
+ """Deinitializes the viewer instance."""
+ self._input_map.clear_bindings()
+ self._camera_settings = self._camera.settings if self._camera else None
+ del self._camera
+ del self._renderer
+ del self._manipulator
+ del self._free_camera
+ del self._camera_select
+ self._camera = None
+ self._renderer = None
+ self._manipulator = None
+ self._free_camera = None
+ self._camera_select = None
+
+ def render(self):
+ """Renders the visualized scene."""
+ if self._camera and self._renderer: # Can be None during env reload.
+ scene = self._camera.render(self.perturbation)
+ self._render_settings.apply_settings(scene)
+ self._renderer.render(self._viewport, scene)
+
+ def zoom_to_scene(self):
+ """Utility method that sets the camera to encompass the entire scene."""
+ if self._camera:
+ self._camera.zoom_to_scene()
+
+ def _print_camera_transform(self):
+ if self._camera:
+ rotation_mtx, position = self._camera.transform
+ right, up, _ = rotation_mtx
+ print('<camera pos="%.3f %.3f %.3f" xyaxes="%.3f %.3f %.3f %.3f %.3f %.3f"/>' % (
+ position[0], position[1], position[2], right[0], right[1],
+ right[2], up[0], up[1], up[2]))
+
+ @property
+ def perturbation(self):
+ """Returns an active renderer.Perturbation object."""
+ if self._manipulator and self._manipulator.perturbation:
+ return self._manipulator.perturbation
+ else:
+ return self._null_perturbation
+
+ @property
+ def camera(self):
+ """Returns an active renderer.SceneCamera instance."""
+ return self._camera
+
+ @property
+ def render_settings(self):
+ """Returns renderer.RenderSettings used by this viewer."""
+ return self._render_settings
+
+
+class CameraSelector:
+ """Binds camera behavior to user input."""
+
+ def __init__(self, model, camera, free_camera, **unused):
+ """Instance initializer.
+
+ Args:
+ model: Instance of MjModel.
+ camera: Instance of SceneCamera.
+ free_camera: Instance of FreeCameraController.
+ **unused: Other arguments, not used by this class.
+ """
+ del unused # Unused.
+ self._model = model
+ self._camera = camera
+ self._free_ctrl = free_camera
+
+ self._camera_idx = -1
+ self._active_ctrl = self._free_ctrl
+
+ def select_previous(self):
+ """Cycles to the previous scene camera."""
+ self._camera_idx -= 1
+ if not self._model.ncam or self._camera_idx < -1:
+ self._camera_idx = self._model.ncam - 1
+ self._commit_selection()
+
+ def select_next(self):
+ """Cycles to the next scene camera."""
+ self._camera_idx += 1
+ if not self._model.ncam or self._camera_idx >= self._model.ncam:
+ self._camera_idx = -1
+ self._commit_selection()
+
+ def escape(self) -> None:
+ """Unconditionally switches to the free camera."""
+ self._camera_idx = -1
+ self._commit_selection()
+
+ def _commit_selection(self):
+ """Selects a controller that should go with the selected camera."""
+ if self._camera_idx < 0:
+ self._activate(self._free_ctrl)
+ else:
+ self._camera.set_fixed_mode(self._camera_idx)
+ self._activate(None)
+
+ def _activate(self, controller):
+ """Activates a sub-controller."""
+ if controller == self._active_ctrl:
+ return
+
+ if self._active_ctrl is not None:
+ self._active_ctrl.deactivate()
+ self._active_ctrl = controller
+ if self._active_ctrl is not None:
+ self._active_ctrl.activate()
+
+
+class FreeCameraController:
+ """Implements the free camera behavior."""
+
+ def __init__(self, viewport, camera, pointer, selection_service, **unused):
+ """Instance initializer.
+
+ Args:
+ viewport: Instance of renderer.Viewport.
+ camera: Instance of renderer.SceneCamera.
+ pointer: A pointer that moves around the screen and is used to point at
+ bodies. Implements a single attribute - 'position' - that returns a
+ 2-component vector of pointer's screen space position.
+ selection_service: An instance of a class implementing a
+ 'selected_body_id' property.
+ **unused: Other optional parameters not used by this class.
+ """
+ del unused # Unused.
+ self._viewport = viewport
+ self._camera = camera
+ self._pointer = pointer
+ self._selection_service = selection_service
+ self._active = True
+ self._tracked_body_idx = -1
+ self._action = util.AtomicAction()
+
+ def activate(self):
+ """Activates the controller."""
+ self._active = True
+ self._update_camera_mode()
+
+ def deactivate(self):
+ """Deactivates the controller."""
+ self._active = False
+ self._action = util.AtomicAction()
+
+ def set_pan_vertical_mode(self, enable):
+ """Starts/ends the camera panning action along the vertical plane.
+
+ Args:
+ enable: A boolean flag, True to start the action, False to end it.
+ """
+ if self._active:
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_MOVE_V)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_MOVE_V)
+
+ def set_pan_horizontal_mode(self, enable):
+ """Starts/ends the camera panning action along the horizontal plane.
+
+ Args:
+ enable: A boolean flag, True to start the action, False to end it.
+ """
+ if self._active:
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_MOVE_H)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_MOVE_H)
+
+ def set_rotate_mode(self, enable):
+ """Starts/ends the camera rotation action.
+
+ Args:
+ enable: A boolean flag, True to start the action, False to end it.
+ """
+ if self._active:
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_ROTATE_H)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_ROTATE_H)
+
+ def center(self):
+ """Focuses camera on the object the pointer is currently pointing at."""
+ if self._active:
+ body_id, world_pos = self._camera.raycast(self._viewport,
+ self._pointer.position)
+ if body_id >= 0:
+ self._camera.look_at(world_pos, _LOOK_AT_DISTANCE)
+
+ def on_move(self, position, translation):
+ """Translates mouse moves onto camera movements."""
+ del position
+ if self._action.in_progress:
+ viewport_offset = self._viewport.screen_to_viewport(translation)
+ self._camera.move(self._action.watermark, viewport_offset)
+
+ def zoom(self, zoom_factor):
+ """Zooms the camera in/out.
+
+ Args:
+ zoom_factor: A floating point value, by how much to zoom the camera.
+ Positive values zoom the camera in, negative values zoom it out.
+ """
+ if self._active:
+ offset = [0, _SCROLL_SPEED_FACTOR * zoom_factor * -1.]
+ self._camera.move(mujoco.mjtMouse.mjMOUSE_ZOOM, offset)
+
+ def track(self):
+ """Makes the camera track the currently selected object.
+
+ The selection is managed by the selection service.
+ """
+ if self._active and self._tracked_body_idx < 0:
+ self._tracked_body_idx = self._selection_service.selected_body_id
+ self._update_camera_mode()
+
+ def free_look(self):
+ """Switches the camera to a free-look mode."""
+ if self._active:
+ self._tracked_body_idx = -1
+ self._update_camera_mode()
+
+ def _update_camera_mode(self):
+ """Sets the camera into a tracking or a free-look mode."""
+ if self._tracked_body_idx >= 0:
+ self._camera.set_tracking_mode(self._tracked_body_idx)
+ else:
+ self._camera.set_freelook_mode()
+
+
+class ManipulationController:
+ """Binds control over scene objects to user input."""
+
+ def __init__(self, viewport, camera, pointer, **unused):
+ """Instance initializer.
+
+ Args:
+ viewport: Instance of renderer.Viewport.
+ camera: Instance of renderer.SceneCamera.
+ pointer: A pointer that moves around the screen and is used to point at
+ bodies. Implements a single attribute - 'position' - that returns a
+ 2-component vector of pointer's screen space position.
+ **unused: Other arguments, unused by this class.
+ """
+ del unused # Unused.
+ self._viewport = viewport
+ self._camera = camera
+ self._pointer = pointer
+ self._action = util.AtomicAction(self._update_action)
+ self._perturb = None
+
+ def select(self):
+ """Translates mouse double-clicks to object selection action."""
+ body_id, _ = self._camera.raycast(self._viewport, self._pointer.position)
+ if body_id >= 0:
+ self._perturb = self._camera.new_perturbation(body_id)
+ else:
+ self._perturb = None
+
+ def set_move_vertical_mode(self, enable):
+ """Begins/ends an object translation action along the vertical plane.
+
+ Args:
+ enable: A boolean flag, True begins the action, False ends it.
+ """
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_MOVE_V)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_MOVE_V)
+
+ def set_move_horizontal_mode(self, enable):
+ """Begins/ends an object translation action along the horizontal plane.
+
+ Args:
+ enable: A boolean flag, True begins the action, False ends it.
+ """
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_MOVE_H)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_MOVE_H)
+
+ def set_rotate_mode(self, enable):
+ """Begins/ends an object rotation action.
+
+ Args:
+ enable: A boolean flag, True begins the action, False ends it.
+ """
+ if enable:
+ self._action.begin(mujoco.mjtMouse.mjMOUSE_ROTATE_H)
+ else:
+ self._action.end(mujoco.mjtMouse.mjMOUSE_ROTATE_H)
+
+ def _update_action(self, action):
+ if self._perturb is not None:
+ if action is not None:
+ _, grab_pos = self._camera.raycast(self._viewport,
+ self._pointer.position)
+ self._perturb.start_move(action, grab_pos)
+ else:
+ self._perturb.end_move()
+
+ def on_move(self, position, translation):
+ """Translates mouse moves to selected object movements."""
+ del position
+ if self._perturb is not None and self._action.in_progress:
+ viewport_offset = self._viewport.screen_to_viewport(translation)
+ self._perturb.tick_move(viewport_offset)
+
+ @property
+ def perturbation(self):
+ """Returns the Perturbation object that represents the manipulated body."""
+ return self._perturb
+
+ @property
+ def selected_body_id(self):
+ """Returns the id of the selected body, or -1 if none is selected."""
+ return self._perturb.body_id if self._perturb is not None else -1
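
Review note on `CameraSelector`: the index arithmetic in `select_next` and `select_previous` treats `-1` as the free camera and wraps around the model's fixed cameras. The sketch below restates that arithmetic as two pure functions so the wrap-around can be checked without a model. The function names and the `cycle` helper are hypothetical and are not part of `viewer.py`.

```python
def next_camera_index(idx, ncam):
  """Mirrors select_next: step forward, falling back to the free camera (-1)."""
  idx += 1
  if not ncam or idx >= ncam:
    idx = -1
  return idx


def previous_camera_index(idx, ncam):
  """Mirrors select_previous: step back, wrapping to the last fixed camera."""
  idx -= 1
  if not ncam or idx < -1:
    idx = ncam - 1
  return idx


def cycle(step, ncam, count):
  """Applies `step` repeatedly, starting from the free camera."""
  idx = -1
  visited = []
  for _ in range(count):
    idx = step(idx, ncam)
    visited.append(idx)
  return visited


forward = cycle(next_camera_index, 3, 5)
backward = cycle(previous_camera_index, 3, 5)
no_cameras = cycle(next_camera_index, 0, 2) + cycle(previous_camera_index, 0, 2)
```

With three fixed cameras the forward cycle visits `0, 1, 2`, returns to the free camera, and starts again. With no fixed cameras both directions stay on the free camera.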
diff --git a/dm_control/viewer/viewer_test.py b/dm_control/viewer/viewer_test.py
new file mode 100644
index 00000000..87692c02
--- /dev/null
+++ b/dm_control/viewer/viewer_test.py
@@ -0,0 +1,479 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests of the viewer.py module."""
+
+
+from absl.testing import absltest
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.viewer import util
+from dm_control.viewer import viewer
+import mock
+
+
+class ViewerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.viewport = mock.MagicMock()
+ self.mouse = mock.MagicMock()
+ self.keyboard = mock.MagicMock()
+ self.viewer = viewer.Viewer(self.viewport, self.mouse, self.keyboard)
+
+ self.viewer._render_settings = mock.MagicMock()
+ self.physics = mock.MagicMock()
+ self.renderer = mock.MagicMock()
+ self.renderer.priority_components = util.QuietSet()
+
+ def _extract_bind_call_args(self, bind_mock):
+ call_args = []
+ for calls in bind_mock.call_args_list:
+ args = calls[0]
+ if len(args) == 2:
+ call_args.append(args[1])
+ return call_args
+
+ def test_initialize_creates_components(self):
+ with mock.patch(viewer.__name__ + '.renderer'):
+ self.viewer.initialize(self.physics, self.renderer, touchpad=False)
+ self.assertIsNotNone(self.viewer._camera)
+ self.assertIsNotNone(self.viewer._manipulator)
+ self.assertIsNotNone(self.viewer._free_camera)
+ self.assertIsNotNone(self.viewer._camera_select)
+ self.assertEqual(self.renderer, self.viewer._renderer)
+
+ def test_initialize_creates_touchpad_specific_input_mapping(self):
+ self.viewer._input_map = mock.MagicMock()
+ with mock.patch(viewer.__name__ + '.renderer'):
+ self.viewer.initialize(self.physics, self.renderer, touchpad=True)
+ call_args = self._extract_bind_call_args(self.viewer._input_map.bind)
+ self.assertIn(viewer._MOVE_OBJECT_VERTICAL_TOUCHPAD, call_args)
+ self.assertIn(viewer._MOVE_OBJECT_HORIZONTAL_TOUCHPAD, call_args)
+ self.assertIn(viewer._ROTATE_OBJECT_TOUCHPAD, call_args)
+ self.assertIn(viewer._PAN_CAMERA_VERTICAL_TOUCHPAD, call_args)
+ self.assertIn(viewer._PAN_CAMERA_HORIZONTAL_TOUCHPAD, call_args)
+
+ def test_initialize_create_mouse_specific_input_mapping(self):
+ self.viewer._input_map = mock.MagicMock()
+ with mock.patch(viewer.__name__ + '.renderer'):
+ self.viewer.initialize(self.physics, self.renderer, touchpad=False)
+ call_args = self._extract_bind_call_args(self.viewer._input_map.bind)
+ self.assertIn(viewer._MOVE_OBJECT_VERTICAL_MOUSE, call_args)
+ self.assertIn(viewer._MOVE_OBJECT_HORIZONTAL_MOUSE, call_args)
+ self.assertIn(viewer._ROTATE_OBJECT_MOUSE, call_args)
+ self.assertIn(viewer._PAN_CAMERA_VERTICAL_MOUSE, call_args)
+ self.assertIn(viewer._PAN_CAMERA_HORIZONTAL_MOUSE, call_args)
+
+ def test_initialization_flushes_old_input_map(self):
+ self.viewer._input_map = mock.MagicMock()
+ with mock.patch(viewer.__name__ + '.renderer'):
+ self.viewer.initialize(self.physics, self.renderer, touchpad=False)
+ self.viewer._input_map.clear_bindings.assert_called_once()
+
+ def test_deinitialization_deletes_components(self):
+ self.viewer._camera = mock.MagicMock()
+ self.viewer._manipulator = mock.MagicMock()
+ self.viewer._free_camera = mock.MagicMock()
+ self.viewer._camera_select = mock.MagicMock()
+ self.viewer._renderer = mock.MagicMock()
+ self.viewer.deinitialize()
+ self.assertIsNone(self.viewer._camera)
+ self.assertIsNone(self.viewer._manipulator)
+ self.assertIsNone(self.viewer._free_camera)
+ self.assertIsNone(self.viewer._camera_select)
+ self.assertIsNone(self.viewer._renderer)
+
+ def test_deinitialization_flushes_old_input_map(self):
+ self.viewer._input_map = mock.MagicMock()
+ self.viewer.deinitialize()
+ self.viewer._input_map.clear_bindings.assert_called_once()
+
+ def test_rendering_uninitialized(self):
+ self.viewer.render() # nothing crashes
+
+ def test_zoom_to_scene_uninitialized(self):
+ self.viewer.zoom_to_scene() # nothing crashes
+
+ def test_rendering(self):
+ self.viewer._camera = mock.MagicMock()
+ self.viewer._renderer = mock.MagicMock()
+ self.viewer.render()
+ self.viewer._camera.render.assert_called_once_with(self.viewer.perturbation)
+ self.viewer._renderer.render.assert_called_once()
+
+ def test_applying_render_settings_before_rendering_a_scene(self):
+ self.viewer._camera = mock.MagicMock()
+ self.viewer._renderer = mock.MagicMock()
+ self.viewer.render()
+ self.viewer._render_settings.apply_settings.assert_called_once()
+
+ def test_zoom_to_scene(self):
+ self.viewer._camera = mock.MagicMock()
+ self.viewer.zoom_to_scene()
+ self.viewer._camera.zoom_to_scene.assert_called_once()
+
+ def test_retrieving_perturbation(self):
+ object_perturbation = mock.MagicMock()
+ self.viewer._manipulator = mock.MagicMock()
+ self.viewer._manipulator.perturbation = object_perturbation
+ self.assertEqual(object_perturbation, self.viewer.perturbation)
+
+ def test_retrieving_perturbation_without_manipulator(self):
+ self.viewer._manipulator = None
+ self.assertEqual(self.viewer._null_perturbation, self.viewer.perturbation)
+
+ def test_retrieving_perturbation_without_selected_object(self):
+ self.viewer._manipulator = mock.MagicMock()
+ self.viewer._manipulator.perturbation = None
+ self.assertEqual(self.viewer._null_perturbation, self.viewer.perturbation)
+
+
+class CameraSelectorTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.camera = mock.MagicMock()
+ self.model = mock.MagicMock()
+ self.free_camera = mock.MagicMock()
+ self.model.ncam = 2
+
+ options = {
+ 'camera': self.camera,
+ 'model': self.model,
+ 'free_camera': self.free_camera
+ }
+
+ self.controller = viewer.CameraSelector(**options)
+
+ def test_activating_freelook_camera_by_default(self):
+ self.assertEqual(self.controller._free_ctrl, self.controller._active_ctrl)
+
+ def test_cycling_forward_through_cameras(self):
+ self.controller.select_next()
+ self.assertIsNone(self.controller._active_ctrl)
+ self.controller._free_ctrl.deactivate.assert_called_once()
+ self.controller._free_ctrl.reset_mock()
+ self.controller._camera.set_fixed_mode.assert_called_once_with(0)
+ self.controller._camera.reset_mock()
+
+ self.controller.select_next()
+ self.assertIsNone(self.controller._active_ctrl)
+ self.controller._camera.set_fixed_mode.assert_called_once_with(1)
+ self.controller._camera.reset_mock()
+
+ self.controller.select_next()
+ self.assertEqual(self.controller._free_ctrl, self.controller._active_ctrl)
+ self.controller._free_ctrl.activate.assert_called_once()
+
+ def test_cycling_backwards_through_cameras(self):
+ self.controller.select_previous()
+ self.assertIsNone(self.controller._active_ctrl)
+ self.controller._free_ctrl.deactivate.assert_called_once()
+ self.controller._free_ctrl.reset_mock()
+ self.controller._camera.set_fixed_mode.assert_called_once_with(1)
+ self.controller._camera.reset_mock()
+
+ self.controller.select_previous()
+ self.assertIsNone(self.controller._active_ctrl)
+ self.controller._camera.set_fixed_mode.assert_called_once_with(0)
+ self.controller._camera.reset_mock()
+
+ self.controller.select_previous()
+ self.assertEqual(self.controller._free_ctrl, self.controller._active_ctrl)
+ self.controller._free_ctrl.activate.assert_called_once()
+
+ def test_controller_activation(self):
+ old_controller = mock.MagicMock()
+ new_controller = mock.MagicMock()
+ self.controller._active_ctrl = old_controller
+ self.controller._activate(new_controller)
+ old_controller.deactivate.assert_called_once()
+ new_controller.activate.assert_called_once()
+
+ def test_controller_activation_not_repeated_for_already_active_one(self):
+ controller = mock.MagicMock()
+ self.controller._active_ctrl = controller
+ self.controller._activate(controller)
+ self.assertEqual(0, controller.deactivate.call_count)
+ self.assertEqual(0, controller.activate.call_count)
+
+
+class FreeCameraControllerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.viewport = mock.MagicMock()
+ self.camera = mock.MagicMock()
+ self.mouse = mock.MagicMock()
+ self.selection_service = mock.MagicMock()
+
+ options = {
+ 'camera': self.camera,
+ 'viewport': self.viewport,
+ 'pointer': self.mouse,
+ 'selection_service': self.selection_service
+ }
+
+ self.controller = viewer.FreeCameraController(**options)
+ self.controller._action = mock.MagicMock()
+
+ def test_activation_while_not_in_tracking_mode(self):
+ self.controller._tracked_body_idx = -1
+ self.controller.activate()
+ self.camera.set_freelook_mode.assert_called_once()
+
+ def test_activation_while_in_tracking_mode(self):
+ self.controller._tracked_body_idx = 1
+ self.controller.activate()
+ self.camera.set_tracking_mode.assert_called_once_with(1)
+
+ def test_activation_and_deactivation_flag(self):
+ self.controller.activate()
+ self.assertTrue(self.controller._active)
+ self.controller.deactivate()
+ self.assertFalse(self.controller._active)
+
+ def test_vertical_panning_camera_with_active_controller(self):
+ self.controller._active = True
+ self.controller.set_pan_vertical_mode(True)
+ self.controller._action.begin.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_MOVE_V)
+ self.controller.set_pan_vertical_mode(False)
+ self.controller._action.end.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_MOVE_V)
+
+ def test_vertical_panning_camera_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller.set_pan_vertical_mode(True)
+ self.assertEqual(0, self.controller._action.begin.call_count)
+ self.controller.set_pan_vertical_mode(False)
+ self.assertEqual(0, self.controller._action.end.call_count)
+
+ def test_horizontal_panning_camera_with_active_controller(self):
+ self.controller._active = True
+ self.controller.set_pan_horizontal_mode(True)
+ self.controller._action.begin.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_MOVE_H)
+ self.controller.set_pan_horizontal_mode(False)
+ self.controller._action.end.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_MOVE_H)
+
+ def test_horizontal_panning_camera_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller.set_pan_horizontal_mode(True)
+ self.assertEqual(0, self.controller._action.begin.call_count)
+ self.controller.set_pan_horizontal_mode(False)
+ self.assertEqual(0, self.controller._action.end.call_count)
+
+ def test_rotating_camera_with_active_controller(self):
+ self.controller._active = True
+ self.controller.set_rotate_mode(True)
+ self.controller._action.begin.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_ROTATE_H)
+ self.controller.set_rotate_mode(False)
+ self.controller._action.end.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_ROTATE_H)
+
+ def test_rotating_camera_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller.set_rotate_mode(True)
+ self.assertEqual(0, self.controller._action.begin.call_count)
+ self.controller.set_rotate_mode(False)
+ self.assertEqual(0, self.controller._action.end.call_count)
+
+ def test_centering_with_active_controller(self):
+ self.controller._active = True
+ self.camera.raycast.return_value = 1, 2
+ self.controller.center()
+ self.camera.raycast.assert_called_once()
+
+ def test_centering_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller.center()
+ self.assertEqual(0, self.camera.raycast.call_count)
+
+ def test_moving_mouse_moves_camera(self):
+ position = [100, 200]
+ translation = [1, 0]
+ viewport_space_translation = [2, 0]
+ action = 1
+ self.viewport.screen_to_viewport.return_value = viewport_space_translation
+ self.controller._action.in_progress = True
+ self.controller._action.watermark = action
+
+ self.controller.on_move(position, translation)
+ self.viewport.screen_to_viewport.assert_called_once_with(translation)
+ self.camera.move.assert_called_once_with(action, viewport_space_translation)
+
+ def test_mouse_move_doesnt_work_without_an_action_selected(self):
+ self.controller._action.in_progress = False
+ self.controller.on_move([], [])
+ self.assertEqual(0, self.camera.move.call_count)
+
+ def test_zoom_with_active_controller(self):
+ self.controller._active = True
+ expected_zoom_vector = [0, -0.05]
+ self.controller.zoom(1.)
+ self.camera.move.assert_called_once_with(
+ enums.mjtMouse.mjMOUSE_ZOOM, expected_zoom_vector)
+
+ def test_zoom_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller.zoom(1.)
+ self.assertEqual(0, self.camera.move.call_count)
+
+ def test_tracking_with_active_controller(self):
+ self.controller._active = True
+ selected_body_id = 5
+ self.selection_service.selected_body_id = selected_body_id
+ self.controller._tracked_body_idx = -1
+
+ self.controller.track()
+ self.assertEqual(self.controller._tracked_body_idx, selected_body_id)
+ self.camera.set_tracking_mode.assert_called_once_with(selected_body_id)
+
+ def test_tracking_with_inactive_controller(self):
+ self.controller._active = False
+ selected_body_id = 5
+ self.selection_service.selected_body_id = selected_body_id
+ self.controller.track()
+ self.assertEqual(self.controller._tracked_body_idx, -1)
+ self.assertEqual(0, self.camera.set_tracking_mode.call_count)
+
+ def test_free_look_mode_with_active_controller(self):
+ self.controller._active = True
+ self.controller._tracked_body_idx = 5
+ self.controller.free_look()
+ self.assertEqual(self.controller._tracked_body_idx, -1)
+ self.camera.set_freelook_mode.assert_called_once()
+
+ def test_free_look_mode_with_inactive_controller(self):
+ self.controller._active = False
+ self.controller._tracked_body_idx = 5
+ self.controller.free_look()
+ self.assertEqual(self.controller._tracked_body_idx, 5)
+ self.assertEqual(0, self.camera.set_freelook_mode.call_count)
+
+
+class ManipulationControllerTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.viewport = mock.MagicMock()
+ self.camera = mock.MagicMock()
+ self.mouse = mock.MagicMock()
+
+ options = {
+ 'camera': self.camera,
+ 'viewport': self.viewport,
+ 'pointer': self.mouse,
+ }
+
+ self.controller = viewer.ManipulationController(**options)
+
+ self.body_id = 1
+ self.click_pos_on_body = [1, 2, 3]
+ self.camera.raycast.return_value = (self.body_id, self.click_pos_on_body)
+
+ def test_selecting_a_body(self):
+ self.camera.raycast.return_value = (self.body_id, self.click_pos_on_body)
+ self.controller.select()
+ self.assertIsNotNone(self.controller._perturb)
+
+ def test_selecting_empty_space_cancels_selection(self):
+ self.camera.raycast.return_value = (-1, None)
+ self.controller.select()
+ self.assertIsNone(self.controller._perturb)
+
+ def test_vertical_movement_operation(self):
+ self.controller._perturb = mock.MagicMock()
+
+ self.controller.set_move_vertical_mode(True)
+ self.controller._perturb.start_move.assert_called_once()
+ self.assertEqual(enums.mjtMouse.mjMOUSE_MOVE_V,
+ self.controller._perturb.start_move.call_args[0][0])
+
+ self.controller.set_move_vertical_mode(False)
+ self.controller._perturb.end_move.assert_called_once()
+
+ def test_horizontal_movement_operation(self):
+ self.controller._perturb = mock.MagicMock()
+
+ self.controller.set_move_horizontal_mode(True)
+ self.controller._perturb.start_move.assert_called_once()
+ self.assertEqual(enums.mjtMouse.mjMOUSE_MOVE_H,
+ self.controller._perturb.start_move.call_args[0][0])
+
+ self.controller.set_move_horizontal_mode(False)
+ self.controller._perturb.end_move.assert_called_once()
+
+ def test_rotation_operation(self):
+ self.controller._perturb = mock.MagicMock()
+
+ self.controller.set_rotate_mode(True)
+ self.controller._perturb.start_move.assert_called_once()
+ self.assertEqual(enums.mjtMouse.mjMOUSE_ROTATE_H,
+ self.controller._perturb.start_move.call_args[0][0])
+
+ self.controller.set_rotate_mode(False)
+ self.controller._perturb.end_move.assert_called_once()
+
+ def test_every_action_generates_a_fresh_grab_pos(self):
+ some_action = 0
+ self.controller._perturb = mock.MagicMock()
+ self.controller._update_action(some_action)
+ self.camera.raycast.assert_called_once()
+
+ def test_actions_not_started_without_object_selected(self):
+ some_action = 0
+ self.controller._perturb = None
+ self.controller._update_action(some_action)
+ self.assertEqual(0, self.camera.raycast.call_count)
+
+ def test_on_move_requires_an_action_to_be_started_first(self):
+ self.controller._perturb = mock.MagicMock()
+ self.controller._action = mock.MagicMock()
+ self.controller._action.in_progress = False
+
+ self.controller.on_move([], [])
+ self.assertEqual(0, self.controller._perturb.tick_move.call_count)
+
+ def test_dragging_selected_object_moves_it(self):
+ screen_pos = [1, 2]
+ screen_translation = [3, 4]
+ viewport_offset = [5, 6]
+ self.controller._perturb = mock.MagicMock()
+ self.controller._action = mock.MagicMock()
+ self.controller._action.in_progress = True
+ self.viewport.screen_to_viewport.return_value = viewport_offset
+
+ self.controller.on_move(screen_pos, screen_translation)
+ self.viewport.screen_to_viewport.assert_called_once_with(screen_translation)
+ self.controller._perturb.tick_move.assert_called_once_with(viewport_offset)
+
+ def test_operations_require_object_to_be_selected(self):
+ self.controller._perturb = None
+
+ # No exceptions should be raised.
+ self.controller.set_move_vertical_mode(True)
+ self.controller.set_move_vertical_mode(False)
+ self.controller.set_move_horizontal_mode(True)
+ self.controller.set_move_horizontal_mode(False)
+ self.controller.set_rotate_mode(True)
+ self.controller.set_rotate_mode(False)
+ self.controller.on_move([1, 2], [3, 4])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/viewer/views.py b/dm_control/viewer/views.py
new file mode 100644
index 00000000..a49d534e
--- /dev/null
+++ b/dm_control/viewer/views.py
@@ -0,0 +1,176 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Components and views that render custom images into Mujoco render frame."""
+
+import abc
+import enum
+
+from dm_control.viewer import renderer
+import mujoco
+import numpy as np
+
+
+class PanelLocation(enum.Enum):
+ TOP_LEFT = mujoco.mjtGridPos.mjGRID_TOPLEFT.value
+ TOP_RIGHT = mujoco.mjtGridPos.mjGRID_TOPRIGHT.value
+ BOTTOM_LEFT = mujoco.mjtGridPos.mjGRID_BOTTOMLEFT.value
+ BOTTOM_RIGHT = mujoco.mjtGridPos.mjGRID_BOTTOMRIGHT.value
+
+
+class BaseViewportView(metaclass=abc.ABCMeta):
+ """Base abstract view class."""
+
+ @abc.abstractmethod
+ def render(self, context, viewport, location):
+ """Renders the view on screen.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ location: Value defined in PanelLocation enum.
+ """
+ pass
+
+
+class ColumnTextModel(metaclass=abc.ABCMeta):
+ """Data model that returns 2 columns of text."""
+
+ @abc.abstractmethod
+ def get_columns(self):
+ """Returns the text to display in two columns.
+
+ Returns:
+ An iterable of tuples of 2 strings. Each tuple has the format
+ (left_column_label, right_column_label).
+ """
+ pass
+
+
+class ColumnTextView(BaseViewportView):
+ """A view displayed in Mujoco render window."""
+
+ def __init__(self, model):
+ """Instance initializer.
+
+ Args:
+ model: Instance of ColumnTextModel.
+ """
+ self._model = model
+
+ def render(self, context, viewport, location):
+ """Renders the overlay on screen.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ location: Value defined in PanelLocation enum.
+ """
+ columns = self._model.get_columns()
+ if not columns:
+ return
+
+ columns = np.asarray(columns)
+ left_column = '\n'.join(columns[:, 0])
+ right_column = '\n'.join(columns[:, 1])
+ mujoco.mjr_overlay(mujoco.mjtFont.mjFONT_NORMAL, location.value,
+ viewport.mujoco_rect, left_column, right_column,
+ context.ptr)
+
+
+class MujocoDepthBuffer(renderer.Component):
+ """Displays the contents of the scene's depth buffer."""
+
+ def __init__(self):
+ self._depth_buffer = np.zeros((1, 1), np.float32)
+
+ def render(self, context, viewport):
+ """Renders the overlay on screen.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ """
+ width_adjustment = viewport.width % 4
+ rect_shape = (viewport.width - width_adjustment, viewport.height)
+
+ if self._depth_buffer is None or self._depth_buffer.shape != rect_shape:
+ self._depth_buffer = np.zeros(
+ (viewport.width, viewport.height), np.float32)
+
+ mujoco.mjr_readPixels(None, self._depth_buffer, viewport.mujoco_rect,
+ context.ptr)
+
+ # Subsample by 4, convert to RGB, and cast to unsigned bytes.
+ depth_rgb = np.repeat(self._depth_buffer[::4, ::4, None] * 255, 3,
+ -1).astype(np.ubyte)
+
+ pos = mujoco.MjrRect(
+ int(3 * viewport.width / 4) + width_adjustment, 0,
+ int(viewport.width / 4), int(viewport.height / 4))
+ mujoco.mjr_drawPixels(depth_rgb, None, pos, context.ptr)
+
+
+class ViewportLayout(renderer.Component):
+ """Layout manager for the render viewport.
+
+ Allows a viewport layout to be created by injecting renderer components even
+ in the absence of a renderer, and then easily reattached between renderers.
+ """
+
+ def __init__(self):
+ """Instance initializer."""
+ self._views = dict()
+
+ def __len__(self):
+ return len(self._views)
+
+ def __contains__(self, key):
+ value = self._views.get(key, None)
+ return value is not None
+
+ def add(self, view, location):
+ """Adds a new view.
+
+ Args:
+ view: BaseViewportView instance.
+ location: Value defined in PanelLocation enum, location of the view in the
+ viewport.
+ """
+ if not isinstance(view, BaseViewportView):
+ raise TypeError(
+ 'View added to this layout needs to implement BaseViewportView.')
+ self._views[view] = location
+
+ def remove(self, view):
+ """Removes a view.
+
+ Args:
+ view: BaseViewportView instance.
+ """
+ self._views.pop(view, None)
+
+ def clear(self):
+ """Removes all attached components."""
+ self._views = dict()
+
+ def render(self, context, viewport):
+ """Renders the overlay on screen.
+
+ Args:
+ context: MjrContext instance.
+ viewport: Viewport instance.
+ """
+ for view, location in self._views.items():
+ view.render(context, viewport, location)
diff --git a/dm_control/viewer/views_test.py b/dm_control/viewer/views_test.py
new file mode 100644
index 00000000..4cd57251
--- /dev/null
+++ b/dm_control/viewer/views_test.py
@@ -0,0 +1,144 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the views.py module."""
+
+
+from absl.testing import absltest
+from dm_control.viewer import views
+import mock
+import numpy as np
+
+
+class ColumnTextViewTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.model = mock.MagicMock()
+ self.view = views.ColumnTextView(self.model)
+
+ self.context = mock.MagicMock()
+ self.viewport = mock.MagicMock()
+ self.location = views.PanelLocation.TOP_LEFT
+
+ def test_rendering_empty_columns(self):
+ self.model.get_columns.return_value = []
+ with mock.patch(views.__name__ + '.mujoco') as mjlib_mock:
+ self.view.render(self.context, self.viewport, self.location)
+ self.assertEqual(0, mjlib_mock.mjr_overlay.call_count)
+
+ def test_rendering(self):
+ self.model.get_columns.return_value = [('', '')]
+ with mock.patch(views.__name__ + '.mujoco') as mjlib_mock:
+ self.view.render(self.context, self.viewport, self.location)
+ mjlib_mock.mjr_overlay.assert_called_once()
+
+
+class MujocoDepthBufferTests(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.component = views.MujocoDepthBuffer()
+
+ self.context = mock.MagicMock()
+ self.viewport = mock.MagicMock()
+
+ def test_updating_buffer_size_after_viewport_resize(self):
+ self.component._depth_buffer = np.zeros((1, 1), np.float32)
+ self.viewport.width = 10
+ self.viewport.height = 10
+
+ with mock.patch(views.__name__ + '.mujoco'):
+ self.component.render(context=self.context, viewport=self.viewport)
+ self.assertEqual((10, 10), self.component._depth_buffer.shape)
+
+ def test_reading_depth_data(self):
+ with mock.patch(views.__name__ + '.mujoco') as mjlib_mock:
+ self.component.render(context=self.context, viewport=self.viewport)
+ mjlib_mock.mjr_readPixels.assert_called_once()
+ self.assertIsNone(mjlib_mock.mjr_readPixels.call_args[0][0])
+
+ @absltest.skip('b/222664582')
+ def test_rendering_position_fixed_to_bottom_right_quarter_of_viewport(self):
+ self.viewport.width = 100
+ self.viewport.height = 100
+ expected_rect = [75, 0, 25, 25]
+ with mock.patch(views.__name__ + '.mujoco') as mjlib_mock:
+ self.component.render(context=self.context, viewport=self.viewport)
+ mjlib_mock.mjr_drawPixels.assert_called_once()
+ render_rect = mjlib_mock.mjr_drawPixels.call_args[0][2]
+ self.assertEqual(expected_rect[0], render_rect.left)
+ self.assertEqual(expected_rect[1], render_rect.bottom)
+ self.assertEqual(expected_rect[2], render_rect.width)
+ self.assertEqual(expected_rect[3], render_rect.height)
+
+
+class ViewportLayoutTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.layout = views.ViewportLayout()
+
+ self.context = mock.MagicMock()
+ self.viewport = mock.MagicMock()
+
+ def test_added_elements_need_to_be_a_view(self):
+ self.element = mock.MagicMock()
+ with self.assertRaises(TypeError):
+ self.layout.add(self.element, views.PanelLocation.TOP_LEFT)
+
+ def test_adding_component(self):
+ self.element = mock.MagicMock(spec=views.BaseViewportView)
+ self.layout.add(self.element, views.PanelLocation.TOP_LEFT)
+ self.assertLen(self.layout, 1)
+
+ def test_adding_same_component_twice_updates_location(self):
+ self.element = mock.MagicMock(spec=views.BaseViewportView)
+ self.layout.add(self.element, views.PanelLocation.TOP_LEFT)
+ self.layout.add(self.element, views.PanelLocation.TOP_RIGHT)
+ self.assertEqual(
+ views.PanelLocation.TOP_RIGHT, self.layout._views[self.element])
+
+ def test_removing_component(self):
+ self.element = mock.MagicMock(spec=views.BaseViewportView)
+ self.layout._views[self.element] = views.PanelLocation.TOP_LEFT
+ self.layout.remove(self.element)
+ self.assertEmpty(self.layout)
+
+ def test_removing_unregistered_component(self):
+ self.element = mock.MagicMock(spec=views.BaseViewportView)
+ self.layout.remove(self.element) # No error is raised
+
+ def test_clearing_layout(self):
+ pos = views.PanelLocation.TOP_LEFT
+ self.layout._views = {mock.MagicMock(spec=views.BaseViewportView): pos
+ for _ in range(3)}
+ self.layout.clear()
+ self.assertEmpty(self.layout)
+
+ def test_rendering_layout(self):
+ positions = [
+ views.PanelLocation.TOP_LEFT,
+ views.PanelLocation.TOP_RIGHT,
+ views.PanelLocation.BOTTOM_LEFT]
+ self.layout._views = {mock.MagicMock(spec=views.BaseViewportView): pos
+ for pos in positions}
+ self.layout.render(self.context, self.viewport)
+ for view, location in self.layout._views.items():
+ view.render.assert_called_once()
+ self.assertEqual(location, view.render.call_args[0][2])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/migration_guide_1.0.md b/migration_guide_1.0.md
new file mode 100644
index 00000000..bc04ab48
--- /dev/null
+++ b/migration_guide_1.0.md
@@ -0,0 +1,150 @@
+# dm_control: 1.0.0 update guide
+
+With 1.0.0, we changed the way dm_control uses the MuJoCo physics simulator, and
+migrated to new Python bindings. For most users, this will require no code
+changes. However, some more advanced users will need to change their code
+slightly. Below is a list of known changes that need to be made. Please contact
+us if you've had to make any further changes that are not listed below.
+
+## Required changes
+
+`dm_control.mujoco.wrapper.mjbindings.types` should not be used. This module
+was specific to the previous implementation of the dm_control Python bindings.
+For example, `types.MJRRECT` should be replaced with `mujoco.MjrRect`.
+
+### `MjData.contact` has changed
+
+`MjData.contact` (often accessed as `Physics.data.contact`) used to offer an
+interface similar to a numpy structured array. For example, `data.contact.geom1`
+used to be a numpy array of geom IDs.
+
+With the 1.0.0 update, `MjData.contact` appears as a list of `MjContact`
+structs. Code that operated on the structured array will have to change. For
+example, the following code used to get an array containing the contact distance
+for all contacts that involve `geom_id`:
+
+```
+contact = physics.data.contact
+involves_geom = (contact.geom1 == geom_id) | (contact.geom2 == geom_id)
+dists = contact[involves_geom].dist
+```
+
+After the upgrade:
+
+```
+contacts = physics.data.contact
+dists = [
+    c.dist for c in contacts if c.geom1 == geom_id or c.geom2 == geom_id
+]
+```
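+
+If downstream code still expects a numpy array, the list comprehension can be
+wrapped in `np.array`. The sketch below is a minimal illustration that runs
+without MuJoCo: `FakeContact` is a hypothetical stand-in for `MjContact` with
+made-up values, and only the `geom1`, `geom2` and `dist` fields used above are
+assumed.
+
+```python
+import dataclasses
+
+import numpy as np
+
+
+@dataclasses.dataclass
+class FakeContact:
+    """Hypothetical stand-in for an MjContact struct (illustration only)."""
+    geom1: int
+    geom2: int
+    dist: float
+
+
+# Made-up contacts; in real code this would be physics.data.contact.
+contacts = [
+    FakeContact(geom1=0, geom2=1, dist=0.01),
+    FakeContact(geom1=2, geom2=3, dist=-0.002),
+    FakeContact(geom1=1, geom2=4, dist=0.0),
+]
+geom_id = 1
+
+# Collect the matching distances, then convert back to a numpy array.
+dists = np.array(
+    [c.dist for c in contacts if geom_id in (c.geom1, c.geom2)])
+```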
+
+### `.ptr.contents` will not work
+
+Code that accesses `.ptr.contents` on objects such as `MjvScene` will need to be
+updated. In most cases, simply using `scene` instead of `scene.ptr.contents`
+will work.
+
+### Different exceptions will be thrown from MuJoCo
+
+Code (mostly in tests) that expects `dm_control.mujoco.wrapper.core.Error`
+exceptions will instead receive exceptions thrown by the `mujoco` library.
+These will often be `ValueError` (for errors caused by input parameters) or
+`mujoco.FatalError` (for low-level errors in MuJoCo).
+
+### Better error handling
+
+The Python interpreter no longer crashes when
+[`mju_error`](https://mujoco.readthedocs.io/en/latest/APIreference.html#mju-error)
+is called. Instead, `mju_error` calls are translated into `mujoco.FatalError`
+exceptions in Python.
+
+When Python callables are used as user-defined MuJoCo callbacks, they are now
+permitted to raise exceptions, which will be correctly propagated back down the
+Python call stack.
+
+### Change of signature for `mj_saveModel`
+
+[`mj_saveModel`](https://mujoco.readthedocs.io/en/latest/APIreference.html#mj-savemodel)
+now expects a numpy `uint8` array rather than a ctypes string buffer, and
+doesn't require a "size" parameter (it's inferred from the numpy array size).
+
+Before:
+```
+model_size = mjlib.mj_sizeModel(model.ptr)
+buf = ctypes.create_string_buffer(model_size)
+mjlib.mj_saveModel(model.ptr, None, buf, model_size)
+```
+
+After:
+```
+model_size = mujoco.mj_sizeModel(model.ptr)
+buf = np.empty(model_size, np.uint8)
+mujoco.mj_saveModel(model.ptr, None, buf)
+```
+
+## Optional changes
+
+The following are some changes that can make your code more concise, but are not
+required for it to continue working.
+
+### Use the mujoco module directly, instead of mjlib
+
+Existing code that uses `dm_control.mujoco.wrapper.mjbindings.mjlib` can
+replace that module with `mujoco` directly. Code that uses `enums` or
+`constants` from `dm_control.mujoco.wrapper.mjbindings` can also use `mujoco`,
+with slight type changes. All `mujoco` functions will accept either the old
+enum values or the new ones.
+
+Before:
+```
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+mjlib.mj_objectVelocity(
+    physics.model.ptr, physics.data.ptr,
+    enums.mjtObj.mjOBJ_SITE,
+    site_id, vel, 0)
+```
+
+After:
+```
+import mujoco
+
+mujoco.mj_objectVelocity(
+    physics.model.ptr, physics.data.ptr,
+    mujoco.mjtObj.mjOBJ_SITE,
+    site_id, vel, 0)
+```
+
+### Assume structs are correctly initialized and memory is managed
+
+The MuJoCo C API includes functions that manage the memory for certain structs.
+Those include functions that allocate memory (e.g. `mj_makeModel`,
+`mj_makeData`, `mjv_makeScene`), functions that free memory (e.g.
+`mj_deleteModel`, `mj_deleteData`, `mjv_freeScene`), and functions that reset a
+struct to its default value (e.g. `mjv_defaultOption`, `mj_defaultVisual`).
+
+The new Python bindings take care of this. Wrapper classes like
+`mujoco.MjvScene` automatically allocate memory when they're created, release
+it when they're deleted, and are initialized with default values.
+
+As such, allocating and freeing functions are not available through the mujoco
+Python bindings. The "default" functions are still available, but in most cases
+the calls can simply be removed.
+
+Before:
+```
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.wrapper import mjbindings
+mjlib = mjbindings.mjlib
+
+scene_option = wrapper.core.MjvOption()
+mjlib.mjv_defaultOption(scene_option.ptr)
+```
+
+After:
+```
+from dm_control.mujoco import wrapper
+
+scene_option = wrapper.core.MjvOption()
+```
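The lifetime rule described in the guide above (allocate and set defaults on creation, free on deletion) can be illustrated without MuJoCo. The sketch below is a pure-Python analogy under stated assumptions: `ManagedOption`, `live_count` and `demo` are hypothetical names that are not part of `mujoco` or `dm_control`, and `weakref.finalize` merely stands in for a C-level free function.

```python
import gc
import weakref


class ManagedOption:
  """Hypothetical stand-in for a binding-managed struct (not a MuJoCo class)."""

  live_count = 0  # Number of instances whose "memory" is currently allocated.

  def __init__(self):
    # "Allocation" and default initialization happen together, so callers
    # never need a separate make/default call.
    self.flags = [False] * 4
    ManagedOption.live_count += 1
    # The finalizer plays the role of the C-level free function. It must not
    # reference `self`, otherwise the object could never be collected.
    weakref.finalize(self, ManagedOption._release)

  @staticmethod
  def _release():
    ManagedOption.live_count -= 1


def demo():
  """Returns the live count while an instance exists and after deletion."""
  option = ManagedOption()
  count_alive = ManagedOption.live_count
  del option    # Dropping the last reference triggers the finalizer.
  gc.collect()  # Only needed on interpreters without reference counting.
  return count_alive, ManagedOption.live_count


print(demo())
```

The caller never invokes a "make", "default" or "free" function explicitly, which is the same reason the `mjv_defaultOption` call can be dropped in the "After" snippet.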
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..c89a21c0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["mujoco", "setuptools >= 40.6.0", "wheel", "pyparsing >= 3.0", "absl-py >= 0.7.0"]
+build-backend = "setuptools.build_meta"
diff --git a/requirements.txt b/requirements.txt
index 8a20ce97..c14939ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,25 @@
-absl-py==0.1.5
-enum34==1.1.6
-future==0.16.0
-glfw==1.4.0
-lxml==4.1.1
-mock==2.0.0
-nose==1.3.7
-numpy==1.13.3
-pillow==5.0.0
-pyparsing==2.2.0
-scipy==1.0.0
-six==1.11.0
+absl-py==2.3.1
+dm-env==1.6
+dm-tree==0.1.9
+glfw==2.9.0
+h5py==3.14.0
+labmaze==1.0.6
+lxml==6.0.1
+mock==5.2.0
+mujoco==3.3.7
+numpy==2.3.2; python_version >= '3.11'
+numpy==2.2.6; python_version == '3.10'
+numpy==2.0.2; python_version == '3.9'
+pillow==11.3.0
+protobuf==3.19.4
+pyopengl==3.1.10
+pyparsing==3.2.3
+pytest==8.4.1
+pytest-xdist==3.8.0
+pytest-timeout==2.4.0
+requests==2.32.5
+scipy==1.16.1; python_version >= '3.11'
+scipy==1.15.3; python_version == '3.10'
+scipy==1.13.1; python_version == '3.9'
+setuptools==80.9.0
+tqdm==4.67.1
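The environment markers on the `numpy` lines above mean that at most one pin applies to any given interpreter version. A minimal sketch of that selection logic follows. `numpy_pin` is a hypothetical helper written only for illustration (pip evaluates the markers itself), and it models just the `python_version` comparisons that appear in this file.

```python
def numpy_pin(python_version):
  """Returns the numpy pin that the markers above select, or None."""
  # Compare as integer tuples: as plain strings, '3.10' would sort before '3.9'.
  major, minor = (int(part) for part in python_version.split('.')[:2])
  if (major, minor) >= (3, 11):
    return 'numpy==2.3.2'
  if (major, minor) == (3, 10):
    return 'numpy==2.2.6'
  if (major, minor) == (3, 9):
    return 'numpy==2.0.2'
  return None  # No marker matches, so no numpy line applies.


for version in ('3.9', '3.10', '3.12'):
  print(version, '->', numpy_pin(version))
```

The `scipy` pins follow the same three-way split, and an interpreter older than 3.9 matches none of the markers.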
diff --git a/setup.py b/setup.py
index 40068344..88bf8bea 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The dm_control Authors.
+# Copyright 2017-2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,18 +15,20 @@
"""Install script for setuptools."""
+import fnmatch
+import logging
import os
+import platform
import subprocess
import sys
-from distutils import cmd
-from distutils import log
+import mujoco
+import setuptools
from setuptools import find_packages
from setuptools import setup
from setuptools.command import install
-from setuptools.command import test
-DEFAULT_HEADERS_DIR = '~/.mujoco/mjpro150/include'
+PLATFORM = platform.system()
# Relative paths to the binding generator script and the output directory.
AUTOWRAP_PATH = 'dm_control/autowrap/autowrap.py'
@@ -39,6 +41,8 @@
'mjdata.h',
'mjmodel.h',
'mjrender.h',
+ 'mjtnum.h',
+ 'mjui.h',
'mjvisualize.h',
'mjxmacro.h',
'mujoco.h',
@@ -49,29 +53,33 @@ def _initialize_mjbindings_options(cmd_instance):
"""Set default values for options relating to `build_mjbindings`."""
# A default value must be assigned to each user option here.
cmd_instance.inplace = 0
- cmd_instance.headers_dir = os.path.expanduser(DEFAULT_HEADERS_DIR)
+ cmd_instance.headers_dir = mujoco.HEADERS_DIR
def _finalize_mjbindings_options(cmd_instance):
"""Post-process options relating to `build_mjbindings`."""
+ headers_dir = os.path.expanduser(cmd_instance.headers_dir)
header_paths = []
for filename in HEADER_FILENAMES:
- full_path = os.path.join(cmd_instance.headers_dir, filename)
+ full_path = os.path.join(headers_dir, filename)
if not os.path.exists(full_path):
raise IOError('Header file {!r} does not exist.'.format(full_path))
header_paths.append(full_path)
cmd_instance.header_paths = ' '.join(header_paths)
-class BuildMJBindingsCommand(cmd.Command):
+class BuildMJBindingsCommand(setuptools.Command):
"""Runs `autowrap.py` to generate the low-level ctypes bindings for MuJoCo."""
+
description = __doc__
user_options = [
# The format is (long option, short option, description).
- ('headers-dir=', None,
- 'Path to directory containing MuJoCo headers.'),
- ('inplace=', None,
- 'Place generated files in source directory rather than `build-lib`.'),
+ ('headers-dir=', None, 'Path to directory containing MuJoCo headers.'),
+ (
+ 'inplace=',
+ None,
+ 'Place generated files in source directory rather than `build-lib`.',
+ ),
]
boolean_options = ['inplace']
@@ -93,9 +101,9 @@ def run(self):
sys.executable or 'python',
AUTOWRAP_PATH,
'--header_paths={}'.format(self.header_paths),
- '--output_dir={}'.format(output_dir)
+ '--output_dir={}'.format(output_dir),
]
- self.announce('Running command: {}'.format(command), level=log.DEBUG)
+ self.announce('Running command: {}'.format(command), level=logging.DEBUG)
try:
# Prepend the current directory to $PYTHONPATH so that internal imports
# in `autowrap` can succeed before we've installed anything.
@@ -113,9 +121,11 @@ class InstallCommand(install.install):
"""Runs 'build_mjbindings' before installation."""
user_options = (
- install.install.user_options + BuildMJBindingsCommand.user_options)
+ install.install.user_options + BuildMJBindingsCommand.user_options
+ )
boolean_options = (
- install.install.boolean_options + BuildMJBindingsCommand.boolean_options)
+ install.install.boolean_options + BuildMJBindingsCommand.boolean_options
+ )
def initialize_options(self):
install.install.initialize_options(self)
@@ -126,57 +136,103 @@ def finalize_options(self):
_finalize_mjbindings_options(self)
def run(self):
- self.reinitialize_command('build_mjbindings',
- inplace=self.inplace,
- headers_dir=self.headers_dir)
+ self.reinitialize_command('build_mjbindings')
self.run_command('build_mjbindings')
install.install.run(self)
-class TestCommand(test.test):
- """Prepends path to generated sources before running unit tests."""
+def find_data_files(package_dir, patterns, excludes=()):
+  """Recursively finds files whose names match the given shell patterns."""
+  paths = set()
+
+  def is_excluded(s):
+    for exclude in excludes:
+      if fnmatch.fnmatch(s, exclude):
+        return True
+    return False
+
+  for directory, _, filenames in os.walk(package_dir):
+    if is_excluded(directory):
+      continue
+    for pattern in patterns:
+      for filename in fnmatch.filter(filenames, pattern):
+        # NB: paths must be relative to the package directory.
+        relative_dirpath = os.path.relpath(directory, package_dir)
+        full_path = os.path.join(relative_dirpath, filename)
+        if not is_excluded(full_path):
+          paths.add(full_path)
+  return list(paths)
- def run(self):
- # Generate ctypes bindings in-place so that they can be imported in tests.
- self.reinitialize_command('build_mjbindings', inplace=1)
- self.run_command('build_mjbindings')
- test.test.run(self)
setup(
name='dm_control',
+ version='1.0.35',
description='Continuous control environments and MuJoCo Python bindings.',
+ long_description="""
+# `dm_control`: DeepMind Infrastructure for Physics-Based Simulation.
+
+DeepMind's software stack for physics-based simulation and Reinforcement
+Learning environments, using MuJoCo physics.
+
+An **introductory tutorial** for this package is available as a Colaboratory
+notebook: [Open In Google Colab](https://colab.research.google.com/github/google-deepmind/dm_control/blob/main/tutorial.ipynb).
+""",
+ long_description_content_type='text/markdown',
author='DeepMind',
- license='Apache License, Version 2.0',
+ author_email='mujoco@deepmind.com',
+ url='https://github.com/google-deepmind/dm_control',
+ license='Apache-2.0',
keywords='machine learning control physics MuJoCo AI',
+ python_requires='>=3.9',
install_requires=[
- 'absl-py',
- 'enum34',
- 'future',
+ 'absl-py>=0.7.0',
+ 'dm-env',
+ 'dm-tree != 0.1.2',
'glfw',
+ 'labmaze',
'lxml',
- 'numpy',
- 'pyparsing',
- 'setuptools',
- 'six',
+ 'mujoco >= 3.3.7',
+ 'numpy >= 1.9.0',
+ 'protobuf >= 3.19.4',
+ 'pyopengl >= 3.1.4',
+ 'pyparsing >= 3.0.0',
+ 'requests',
+ 'setuptools!=50.0.0', # https://github.com/pypa/setuptools/issues/2350
+ 'scipy',
+ 'tqdm',
],
+ extras_require={
+ 'HDF5': ['h5py'],
+ },
tests_require=[
'mock',
- 'nose',
- 'pillow',
- 'scipy',
+ 'pytest',
+ 'pillow>=10.2.0',
],
- test_suite='nose.collector',
packages=find_packages(),
package_data={
- 'dm_control.mujoco.testing':
- ['assets/*.png', 'assets/frames/*.png', 'assets/*.stl', 'assets/*.xml'],
- 'dm_control.suite':
- ['*.xml', 'common/*.xml'],
+ 'dm_control': find_data_files(
+ package_dir='dm_control',
+ patterns=[
+ '*.amc',
+ '*.msh',
+ '*.png',
+ '*.skn',
+ '*.stl',
+ '*.xml',
+ '*.textproto',
+ '*.h5',
+ '*.zip',
+ ],
+ excludes=[
+ '*/dog_assets/extras/*',
+ '*/kinova/meshes/*', # Exclude non-decimated meshes.
+ ],
+ ),
},
cmdclass={
'build_mjbindings': BuildMJBindingsCommand,
'install': InstallCommand,
- 'test': TestCommand,
},
entry_points={},
)
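To see how `find_data_files` combines `patterns` and `excludes`, the function added in this diff can be run against a throwaway directory tree. The sketch below copies the function as it appears above. The `demo` helper and the file names in it are assumptions made for illustration and do not come from the repository.

```python
import fnmatch
import os
import tempfile


def find_data_files(package_dir, patterns, excludes=()):
  """Recursively finds files whose names match the given shell patterns."""
  paths = set()

  def is_excluded(s):
    for exclude in excludes:
      if fnmatch.fnmatch(s, exclude):
        return True
    return False

  for directory, _, filenames in os.walk(package_dir):
    if is_excluded(directory):
      continue
    for pattern in patterns:
      for filename in fnmatch.filter(filenames, pattern):
        # NB: paths must be relative to the package directory.
        relative_dirpath = os.path.relpath(directory, package_dir)
        full_path = os.path.join(relative_dirpath, filename)
        if not is_excluded(full_path):
          paths.add(full_path)
  return list(paths)


def demo():
  """Builds a temporary package tree and returns the sorted matches."""
  with tempfile.TemporaryDirectory() as package_dir:
    for relative_path in (
        'suite/cartpole.xml',
        'suite/notes.txt',
        'third_party/kinova/arm_decimated.stl',
        'third_party/kinova/meshes/arm.stl',
    ):
      path = os.path.join(package_dir, *relative_path.split('/'))
      os.makedirs(os.path.dirname(path), exist_ok=True)
      open(path, 'w').close()
    found = find_data_files(
        package_dir,
        patterns=['*.xml', '*.stl'],
        excludes=['*/kinova/meshes/*'],
    )
  # Normalize separators so the result reads the same on every platform.
  return sorted(p.replace(os.sep, '/') for p in found)


print(demo())
```

With this tree, `demo()` returns `['suite/cartpole.xml', 'third_party/kinova/arm_decimated.stl']`: `notes.txt` matches no pattern, and the mesh under `kinova/meshes` is dropped by the exclude. Because the returned paths are relative and carry no leading separator, an exclude of the form `*/kinova/meshes/*` only matches when `kinova` sits at least one directory below `package_dir`.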
diff --git a/tech_report.pdf b/tech_report.pdf
deleted file mode 100644
index 904aa0eb..00000000
Binary files a/tech_report.pdf and /dev/null differ
diff --git a/tutorial.ipynb b/tutorial.ipynb
new file mode 100644
index 00000000..495789d1
--- /dev/null
+++ b/tutorial.ipynb
@@ -0,0 +1,1848 @@
+{
+ "cells": [
+ {
+ "metadata": {
+ "id": "MpkYHwCqk7W-"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# **`dm_control` tutorial**\n",
+ "\n",
+ "[Open In Google Colab](https://colab.research.google.com/github/google-deepmind/dm_control/blob/main/tutorial.ipynb)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "_UbO9uhtBSX5"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eCopyright 2020 The dm_control Authors.\u003c/small\u003e\u003c/p\u003e\n",
+ "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at \u003ca href=\"http://www.apache.org/licenses/LICENSE-2.0\"\u003ehttp://www.apache.org/licenses/LICENSE-2.0\u003c/a\u003e.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n",
+ "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "aThGKGp0cD76"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This notebook provides an overview tutorial of DeepMind's `dm_control` package, hosted at the [google-deepmind/dm_control](https://github.com/google-deepmind/dm_control) repository on GitHub.\n",
+ "\n",
+ "It accompanies this [tech report](http://arxiv.org/abs/2006.12983).\n",
+ "\n",
+ "**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime \u003e Change runtime type\"."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "YkBQUjm6gbGF"
+ },
+ "cell_type": "markdown",
+ "source": [
+ ""
+ ]
+ },
+ {
+ "metadata": {
+ "id": "YvyGCsgSCxHQ"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "### Installing `dm_control` on Colab"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "IbZxYDxzoz5R"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Run to install MuJoCo and `dm_control`\n",
+ "import distutils.util\n",
+ "import os\n",
+ "import subprocess\n",
+ "if subprocess.run('nvidia-smi').returncode:\n",
+ "  raise RuntimeError(\n",
+ "      'Cannot communicate with GPU. '\n",
+ "      'Make sure you are using a GPU Colab runtime. '\n",
+ "      'Go to the Runtime menu and select Change runtime type.')\n",
+ "\n",
+ "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
+ "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
+ "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
+ "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
+ "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
+ "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
+ "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
+ "    f.write(\"\"\"{\n",
+ "    \"file_format_version\" : \"1.0.0\",\n",
+ "    \"ICD\" : {\n",
+ "        \"library_path\" : \"libEGL_nvidia.so.0\"\n",
+ "    }\n",
+ "}\n",
+ "\"\"\")\n",
+ "\n",
+ "print('Installing dm_control...')\n",
+ "!pip install -q 'dm_control\u003e=1.0.35'\n",
+ "\n",
+ "# Configure dm_control to use the EGL rendering backend (requires GPU)\n",
+ "%env MUJOCO_GL=egl\n",
+ "\n",
+ "print('Checking that the dm_control installation succeeded...')\n",
+ "try:\n",
+ "  from dm_control import suite\n",
+ "  env = suite.load('cartpole', 'swingup')\n",
+ "  pixels = env.physics.render()\n",
+ "except Exception as e:\n",
+ "  raise RuntimeError(\n",
+ "      'Something went wrong during installation. Check the shell output above '\n",
+ "      'for more information.\\n'\n",
+ "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
+ "      'by going to the Runtime menu and selecting \"Change runtime type\".') from e\n",
+ "else:\n",
+ "  del pixels, suite\n",
+ "\n",
+ "!echo Installed dm_control $(pip show dm_control | grep -Po \"(?\u003c=Version: ).+\")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "wtDN43hIJh2C"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Imports\n",
+ "\n",
+ "Run both of these cells:"
+ ]
+ },
+ {
+ "metadata": {
+ "cellView": "form",
+ "id": "T5f4w3Kq2X14"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title All `dm_control` imports required for this tutorial\n",
+ "\n",
+ "# The basic mujoco wrapper.\n",
+ "from dm_control import mujoco\n",
+ "\n",
+ "# Access to enums and MuJoCo library functions.\n",
+ "from dm_control.mujoco.wrapper.mjbindings import enums\n",
+ "from dm_control.mujoco.wrapper.mjbindings import mjlib\n",
+ "\n",
+ "# PyMJCF\n",
+ "from dm_control import mjcf\n",
+ "\n",
+ "# Composer high level imports\n",
+ "from dm_control import composer\n",
+ "from dm_control.composer.observation import observable\n",
+ "from dm_control.composer import variation\n",
+ "\n",
+ "# Imports for Composer tutorial example\n",
+ "from dm_control.composer.variation import distributions\n",
+ "from dm_control.composer.variation import noises\n",
+ "from dm_control.locomotion.arenas import floors\n",
+ "\n",
+ "# Control Suite\n",
+ "from dm_control import suite\n",
+ "\n",
+ "# Run through corridor example\n",
+ "from dm_control.locomotion.walkers import cmu_humanoid\n",
+ "from dm_control.locomotion.arenas import corridors as corridor_arenas\n",
+ "from dm_control.locomotion.tasks import corridors as corridor_tasks\n",
+ "\n",
+ "# Soccer\n",
+ "from dm_control.locomotion import soccer\n",
+ "\n",
+ "# Manipulation\n",
+ "from dm_control import manipulation"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "gKc1FNhKiVJX"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Other imports and helper functions\n",
+ "\n",
+ "# General\n",
+ "import copy\n",
+ "import os\n",
+ "import itertools\n",
+ "from IPython.display import clear_output\n",
+ "import numpy as np\n",
+ "\n",
+ "# Graphics-related\n",
+ "import matplotlib\n",
+ "import matplotlib.animation as animation\n",
+ "import matplotlib.pyplot as plt\n",
+ "from IPython.display import HTML\n",
+ "import PIL.Image\n",
+ "# Internal loading of video libraries.\n",
+ "\n",
+ "# Use svg backend for figure rendering\n",
+ "%config InlineBackend.figure_format = 'svg'\n",
+ "\n",
+ "# Font sizes\n",
+ "SMALL_SIZE = 8\n",
+ "MEDIUM_SIZE = 10\n",
+ "BIGGER_SIZE = 12\n",
+ "plt.rc('font', size=SMALL_SIZE) # controls default text sizes\n",
+ "plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title\n",
+ "plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n",
+ "plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n",
+ "plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n",
+ "plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize\n",
+ "plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title\n",
+ "\n",
+ "# Inline video helper function\n",
+ "if os.environ.get('COLAB_NOTEBOOK_TEST', False):\n",
+ "  # We skip video generation during tests, as it is quite expensive.\n",
+ "  display_video = lambda *args, **kwargs: None\n",
+ "else:\n",
+ "  def display_video(frames, framerate=30):\n",
+ "    height, width, _ = frames[0].shape\n",
+ "    dpi = 70\n",
+ "    orig_backend = matplotlib.get_backend()\n",
+ "    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.\n",
+ "    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)\n",
+ "    matplotlib.use(orig_backend)  # Switch back to the original backend.\n",
+ "    ax.set_axis_off()\n",
+ "    ax.set_aspect('equal')\n",
+ "    ax.set_position([0, 0, 1, 1])\n",
+ "    im = ax.imshow(frames[0])\n",
+ "    def update(frame):\n",
+ "      im.set_data(frame)\n",
+ "      return [im]\n",
+ "    interval = 1000/framerate\n",
+ "    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n",
+ "                                   interval=interval, blit=True, repeat=False)\n",
+ "    return HTML(anim.to_html5_video())\n",
+ "\n",
+ "# Seed numpy's global RNG so that cell outputs are deterministic. We also try to\n",
+ "# use RandomState instances that are local to a single cell wherever possible.\n",
+ "np.random.seed(42)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "jZXz9rPYGA-Y"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Model definition, compilation and rendering\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "MRBaZsf1d7Gb"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We begin by describing some basic concepts of the [MuJoCo](http://mujoco.org/) physics simulation library, but recommend the [official documentation](http://mujoco.org/book/) for details.\n",
+ "\n",
+ "Let's define a simple model with two geoms and a light."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "ZS2utl59ZTsr"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title A static model {vertical-output: true}\n",
+ "\n",
+ "static_model = \"\"\"\n",
+ "\u003cmujoco\u003e\n",
+ " \u003cworldbody\u003e\n",
+ " \u003clight name=\"top\" pos=\"0 0 1\"/\u003e\n",
+ " \u003cgeom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/\u003e\n",
+ " \u003cgeom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/\u003e\n",
+ " \u003c/worldbody\u003e\n",
+ "\u003c/mujoco\u003e\n",
+ "\"\"\"\n",
+ "physics = mujoco.Physics.from_xml_string(static_model)\n",
+ "pixels = physics.render()\n",
+ "PIL.Image.fromarray(pixels)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "p4vPllljTJh8"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "`static_model` is written in MuJoCo's XML-based [MJCF](http://www.mujoco.org/book/modeling.html) modeling language. The `from_xml_string()` method invokes the model compiler, which instantiates the library's internal data structures. These can be accessed via the `physics` object, as shown below."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "MdUF2UYmR4TA"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Adding DOFs and simulating, advanced rendering\n",
+ "This is a perfectly legitimate model, but if we simulate it, nothing will happen except for time advancing. This is because this model has no degrees of freedom (DOFs). We add DOFs by adding **joints** to bodies, specifying how they can move with respect to their parents. Let us add a hinge joint and re-render, visualizing the joint axis."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "R7zokzd_yeEg"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title A child body with a joint { vertical-output: true }\n",
+ "\n",
+ "swinging_body = \"\"\"\n",
+ "\u003cmujoco\u003e\n",
+ " \u003cworldbody\u003e\n",
+ " \u003clight name=\"top\" pos=\"0 0 1\"/\u003e\n",
+ " \u003cbody name=\"box_and_sphere\" euler=\"0 0 -30\"\u003e\n",
+ " \u003cjoint name=\"swing\" type=\"hinge\" axis=\"1 -1 0\" pos=\"-.2 -.2 -.2\"/\u003e\n",
+ " \u003cgeom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/\u003e\n",
+ " \u003cgeom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/\u003e\n",
+ " \u003c/body\u003e\n",
+ " \u003c/worldbody\u003e\n",
+ "\u003c/mujoco\u003e\n",
+ "\"\"\"\n",
+ "physics = mujoco.Physics.from_xml_string(swinging_body)\n",
+ "# Visualize the joint axis.\n",
+ "scene_option = mujoco.wrapper.core.MjvOption()\n",
+ "scene_option.flags[enums.mjtVisFlag.mjVIS_JOINT] = True\n",
+ "pixels = physics.render(scene_option=scene_option)\n",
+ "PIL.Image.fromarray(pixels)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "INOGGV0PTQus"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The things that move (and which have inertia) are called *bodies*. The body's child `joint` specifies how that body can move with respect to its parent, in this case `box_and_sphere` w.r.t the `worldbody`. \n",
+ "\n",
+ "Note that the body's frame is **rotated** with an `euler` directive, and its children, the geoms and the joint, rotate with it. This is to emphasize the local-to-parent-frame nature of position and orientation directives in MJCF.\n",
+ "\n",
+ "Let's make a video, to get a sense of the dynamics and to see the body swinging under gravity."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Z_57VMUDpGrj"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Making a video {vertical-output: true}\n",
+ "\n",
+ "duration = 2 # (seconds)\n",
+ "framerate = 30 # (Hz)\n",
+ "\n",
+ "# Visualize the joint axis\n",
+ "scene_option = mujoco.wrapper.core.MjvOption()\n",
+ "scene_option.flags[enums.mjtVisFlag.mjVIS_JOINT] = True\n",
+ "\n",
+ "# Simulate and display video.\n",
+ "frames = []\n",
+ "physics.reset() # Reset state and time\n",
+ "while physics.data.time \u003c duration:\n",
+ "  physics.step()\n",
+ "  if len(frames) \u003c physics.data.time * framerate:\n",
+ "    pixels = physics.render(scene_option=scene_option)\n",
+ "    frames.append(pixels)\n",
+ "display_video(frames, framerate)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "yYvS1UaciMX_"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Note how we collect the video frames. Because physics simulation timesteps are generally much shorter than the interval between video frames (the default timestep is 2ms, while a 30Hz video shows a new frame roughly every 33ms), we don't render after each step."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "nQ8XOnRQx7T1"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Rendering options\n",
+ "\n",
+ "Like joint visualisation, additional rendering options are exposed as parameters to the `render` method."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "AQITZiIgx7T2"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Enable transparency and frame visualization {vertical-output: true}\n",
+ "\n",
+ "scene_option = mujoco.wrapper.core.MjvOption()\n",
+ "scene_option.frame = enums.mjtFrame.mjFRAME_GEOM\n",
+ "scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = True\n",
+ "pixels = physics.render(scene_option=scene_option)\n",
+ "PIL.Image.fromarray(pixels)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "PDDgY48vx7T6"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Depth rendering {vertical-output: true}\n",
+ "\n",
+ "# depth is a float array, in meters.\n",
+ "depth = physics.render(depth=True)\n",
+ "# Shift nearest values to the origin.\n",
+ "depth -= depth.min()\n",
+ "# Scale by 2 mean distances of near rays.\n",
+ "depth /= 2*depth[depth \u003c= 1].mean()\n",
+ "# Scale to [0, 255]\n",
+ "pixels = 255*np.clip(depth, 0, 1)\n",
+ "PIL.Image.fromarray(pixels.astype(np.uint8))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "PNwiIrgpx7T8"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Segmentation rendering {vertical-output: true}\n",
+ "\n",
+ "seg = physics.render(segmentation=True)\n",
+ "# Display the contents of the first channel, which contains object\n",
+ "# IDs. The second channel, seg[:, :, 1], contains object types.\n",
+ "geom_ids = seg[:, :, 0]\n",
+ "# Infinity is mapped to -1\n",
+ "geom_ids = geom_ids.astype(np.float64) + 1\n",
+ "# Scale to [0, 1]\n",
+ "geom_ids = geom_ids / geom_ids.max()\n",
+ "pixels = 255*geom_ids\n",
+ "PIL.Image.fromarray(pixels.astype(np.uint8))\n"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "uCJQlv3cQcJQ"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Projecting from world to camera coordinates {vertical-output: true}\n",
+ "\n",
+ "# Get the world coordinates of the box corners\n",
+ "box_pos = physics.named.data.geom_xpos['red_box']\n",
+ "box_mat = physics.named.data.geom_xmat['red_box'].reshape(3, 3)\n",
+ "box_size = physics.named.model.geom_size['red_box']\n",
+ "offsets = np.array([-1, 1]) * box_size[:, None]\n",
+ "xyz_local = np.stack(list(itertools.product(*offsets))).T\n",
+ "xyz_global = box_pos[:, None] + box_mat @ xyz_local\n",
+ "\n",
+ "# Camera matrices multiply homogeneous [x, y, z, 1] vectors.\n",
+ "corners_homogeneous = np.ones((4, xyz_global.shape[1]), dtype=float)\n",
+ "corners_homogeneous[:3, :] = xyz_global\n",
+ "\n",
+ "# Get the camera matrix.\n",
+ "camera = mujoco.Camera(physics)\n",
+ "camera_matrix = camera.matrix\n",
+ "\n",
+ "# Project world coordinates into pixel space. See:\n",
+ "# https://en.wikipedia.org/wiki/3D_projection#Mathematical_formula\n",
+ "xs, ys, s = camera_matrix @ corners_homogeneous\n",
+ "# x and y are in the pixel coordinate system.\n",
+ "x = xs / s\n",
+ "y = ys / s\n",
+ "\n",
+ "# Render the camera view and overlay the projected corner coordinates.\n",
+ "pixels = camera.render()\n",
+ "fig, ax = plt.subplots(1, 1)\n",
+ "ax.imshow(pixels)\n",
+ "ax.plot(x, y, '+', c='w')\n",
+ "ax.set_axis_off()"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "gf9h_wi9weet"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# MuJoCo basics and named indexing"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "NCcZxrDDB1Cj"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## `mjModel`\n",
+ "MuJoCo's `mjModel`, encapsulated in `physics.model`, contains the *model description*, including the default initial state and other fixed quantities which are not a function of the state, e.g. the positions of geoms in the frame of their parent body. The (x, y, z) offsets of the `box` and `sphere` geoms, relative to their parent body `box_and_sphere`, are given by `model.geom_pos`:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "wx8NANvOF8g1"
+ },
+ "cell_type": "code",
+ "source": [
+ "physics.model.geom_pos"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "Wee5ATLtIQn_"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `model.opt` structure contains global quantities like"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BhzbZIfDIU2-"
+ },
+ "cell_type": "code",
+ "source": [
+ "print('timestep', physics.model.opt.timestep)\n",
+ "print('gravity', physics.model.opt.gravity)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "t5hY0fyXFLcf"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## `mjData`\n",
+ "`mjData`, encapsulated in `physics.data`, contains the *state* and quantities that depend on it. The state is made up of time, generalized positions and generalized velocities. These are respectively `data.time`, `data.qpos` and `data.qvel`.\n",
+ "\n",
+ "Let's print the state of the swinging body where we left it:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "acwZtDwp9mQU"
+ },
+ "cell_type": "code",
+ "source": [
+ "print(physics.data.time, physics.data.qpos, physics.data.qvel)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "7YlmcLcA-WQu"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "`physics.data` also contains functions of the state, for example the Cartesian positions of objects in the world frame. The (x, y, z) positions of our two geoms are in `data.geom_xpos`:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "CPwDcAQ0-uUE"
+ },
+ "cell_type": "code",
+ "source": [
+ "print(physics.data.geom_xpos)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "Z0UodCxS_v49"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Named indexing\n",
+ "\n",
+ "The semantics of the above arrays are made clearer using the `named` wrapper, which assigns names to rows and type names to columns."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "cLARcaK6-xCU"
+ },
+ "cell_type": "code",
+ "source": [
+ "print(physics.named.data.geom_xpos)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "wgXOUZNZHIx6"
+ },
+ "cell_type": "markdown",
+ "source": [
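+ "Note how `model.geom_pos` and `data.geom_xpos` have similar layouts but very different meanings: the former holds geom offsets in the frame of the parent body, while the latter holds geom positions in the world frame."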
+ ]
+ },
+ {
+ "metadata": {
+ "id": "-cW61ClRHS8a"
+ },
+ "cell_type": "code",
+ "source": [
+ "print(physics.named.model.geom_pos)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "-lQ0AChVASMv"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Name strings can be used to index **into** the relevant quantities, making code much more readable and robust."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Rj4ad9fQAnFZ"
+ },
+ "cell_type": "code",
+ "source": [
+ "physics.named.data.geom_xpos['green_sphere', 'z']"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "axr_p6APAzFn"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Joint names can be used to index into quantities in configuration space (beginning with the letter `q`):"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "hluF9aDG9O1W"
+ },
+ "cell_type": "code",
+ "source": [
+ "physics.named.data.qpos['swing']"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "3IhfyD2Q1pjv"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We can mix NumPy slicing operations with named indexing. As an example, we can set the color of the box using its name (`\"red_box\"`) as an index into the rows of the `geom_rgba` array. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "f5vVUullUvWH"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Changing colors using named indexing{vertical-output: true}\n",
+ "\n",
+ "random_rgb = np.random.rand(3)\n",
+ "physics.named.model.geom_rgba['red_box', :3] = random_rgb\n",
+ "pixels = physics.render()\n",
+ "PIL.Image.fromarray(pixels)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "elzPPdq-KhLI"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Note that while `physics.model` quantities will not be changed by the engine, we can change them ourselves between steps. This however is generally not recommended, the preferred approach being to modify the model at the XML level using the PyMJCF library, see below."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "22ENjtVuhwsm"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Setting the state with `reset_context()`\n",
+ "\n",
+ "In order for `data` quantities that are functions of the state to be in sync with the state, MuJoCo's `mj_step1()` needs to be called. This is facilitated by the `reset_context()` context manager; for an in-depth discussion, see Section 2.1 of the [tech report](https://arxiv.org/abs/2006.12983)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WBPprCtWgXFN"
+ },
+ "cell_type": "code",
+ "source": [
+ "physics.named.data.qpos['swing'] = np.pi\n",
+ "print('Without reset_context, spatial positions are not updated:',\n",
+ "      physics.named.data.geom_xpos['green_sphere', ['z']])\n",
+ "with physics.reset_context():\n",
+ "  physics.named.data.qpos['swing'] = np.pi\n",
+ "print('After reset_context, positions are up-to-date:',\n",
+ "      physics.named.data.geom_xpos['green_sphere', ['z']])"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "SHppAOjvSupc"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Free bodies: the self-inverting \"tippe-top\"\n",
+ "\n",
+ "A free body is a body with a `free` joint, with 6 movement DOFs: 3 translations and 3 rotations. We could give our `box_and_sphere` body a free joint and watch it fall, but let's look at something more interesting. A \"tippe top\" is a spinning toy which flips itself on its head ([Wikipedia](https://en.wikipedia.org/wiki/Tippe_top)). We model it as follows:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "xasXQpVMjIwA"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title The \"tippe-top\" model{vertical-output: true}\n",
+ "\n",
+ "tippe_top = \"\"\"\n",
+ "\u003cmujoco model=\"tippe top\"\u003e\n",
+ "  \u003coption integrator=\"RK4\"/\u003e\n",
+ "  \u003casset\u003e\n",
+ "    \u003ctexture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\"\n",
+ "     rgb2=\".2 .3 .4\" width=\"300\" height=\"300\"/\u003e\n",
+ "    \u003cmaterial name=\"grid\" texture=\"grid\" texrepeat=\"8 8\" reflectance=\".2\"/\u003e\n",
+ "  \u003c/asset\u003e\n",
+ "  \u003cworldbody\u003e\n",
+ "    \u003cgeom size=\".2 .2 .01\" type=\"plane\" material=\"grid\"/\u003e\n",
+ "    \u003clight pos=\"0 0 .6\"/\u003e\n",
+ "    \u003ccamera name=\"closeup\" pos=\"0 -.1 .07\" xyaxes=\"1 0 0 0 1 2\"/\u003e\n",
+ "    \u003cbody name=\"top\" pos=\"0 0 .02\"\u003e\n",
+ "      \u003cfreejoint/\u003e\n",
+ "      \u003cgeom name=\"ball\" type=\"sphere\" size=\".02\" /\u003e\n",
+ "      \u003cgeom name=\"stem\" type=\"cylinder\" pos=\"0 0 .02\" size=\"0.004 .008\"/\u003e\n",
+ "      \u003cgeom name=\"ballast\" type=\"box\" size=\".023 .023 0.005\" pos=\"0 0 -.015\"\n",
+ "       contype=\"0\" conaffinity=\"0\" group=\"3\"/\u003e\n",
+ "    \u003c/body\u003e\n",
+ "  \u003c/worldbody\u003e\n",
+ "  \u003ckeyframe\u003e\n",
+ "    \u003ckey name=\"spinning\" qpos=\"0 0 0.02 1 0 0 0\" qvel=\"0 0 0 0 1 200\" /\u003e\n",
+ "  \u003c/keyframe\u003e\n",
+ "\u003c/mujoco\u003e\n",
+ "\"\"\"\n",
+ "physics = mujoco.Physics.from_xml_string(tippe_top)\n",
+ "PIL.Image.fromarray(physics.render(camera_id='closeup'))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "bvHlr6maJYIG"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Note several new features of this model definition:\n",
+ "\n",
+ "0. The free joint is added with the `\u003cfreejoint/\u003e` clause, which is similar to `\u003cjoint type=\"free\"/\u003e`, but prohibits unphysical attributes like friction or stiffness.\n",
+ "1. We use the `\u003coption/\u003e` clause to set the integrator to the more accurate 4th-order Runge-Kutta.\n",
+ "2. We define the floor's grid material inside the `\u003casset/\u003e` clause and reference it in the floor geom.\n",
+ "3. We use an invisible and non-colliding box geom called `ballast` to move the top's center-of-mass lower. Having a low center of mass is (counter-intuitively) required for the flipping behaviour to occur.\n",
+ "4. We save our initial spinning state as a keyframe. It has a high rotational velocity around the z-axis, but the spin axis is not perfectly aligned with the vertical.\n",
+ "5. We define a `\u003ccamera\u003e` in our model, and then render from it using the `camera_id` argument to `render()`.\n",
+ "\n",
+ "Let us examine the state:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "o4S9nYhHOKmb"
+ },
+ "cell_type": "code",
+ "source": [
+ "print('positions', physics.data.qpos)\n",
+ "print('velocities', physics.data.qvel)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "71UgzBAqWdtZ"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The velocities are easy to interpret: 6 zeros, one for each DOF. What about the length-7 positions? The first three numbers are the position of the body, where we can see its initial 2 cm height; the subsequent four numbers are the 3D orientation, defined by a *unit quaternion*. These normalized four-vectors, which preserve the topology of the orientation group, are the reason that `data.qpos` can be bigger than `data.qvel`: 3D orientations are represented with **4** numbers while angular velocities are **3** numbers."
+ ]
+ },
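+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "To make the quaternion representation concrete, here is a minimal NumPy sketch. It is not part of `dm_control`, and `axis_angle_to_quat` is a hypothetical helper introduced only for illustration. It builds a unit quaternion in MuJoCo's `(w, x, y, z)` ordering from a rotation axis and an angle. A zero rotation gives `[1, 0, 0, 0]`, which is the orientation stored in the keyframe's `qpos`."
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "def axis_angle_to_quat(axis, angle):\n",
+ "  \"\"\"Returns the unit quaternion (w, x, y, z) for a rotation about `axis`.\"\"\"\n",
+ "  axis = np.asarray(axis, dtype=float)\n",
+ "  axis = axis / np.linalg.norm(axis)  # Make sure the axis has unit length.\n",
+ "  return np.concatenate([[np.cos(angle / 2)], np.sin(angle / 2) * axis])\n",
+ "\n",
+ "\n",
+ "no_rotation = axis_angle_to_quat([0, 0, 1], 0.0)\n",
+ "quarter_turn = axis_angle_to_quat([0, 0, 1], np.pi / 2)\n",
+ "print('no rotation:', no_rotation)\n",
+ "print('quarter turn about z:', quarter_turn)\n",
+ "# Four numbers with one constraint (unit norm) leave 3 rotational DOFs.\n",
+ "print('norm:', np.linalg.norm(quarter_turn))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },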
+ {
+ "metadata": {
+ "id": "5P4HkhKNGQvs"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Video of the tippe-top {vertical-output: true}\n",
+ "\n",
+ "duration = 7 # (seconds)\n",
+ "framerate = 60 # (Hz)\n",
+ "\n",
+ "# Simulate and display video.\n",
+ "frames = []\n",
+ "physics.reset(0) # Reset to keyframe 0 (load a saved state).\n",
+ "while physics.data.time \u003c duration:\n",
+ "  physics.step()\n",
+ "  if len(frames) \u003c (physics.data.time) * framerate:\n",
+ "    pixels = physics.render(camera_id='closeup')\n",
+ "    frames.append(pixels)\n",
+ "\n",
+ "display_video(frames, framerate)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "rRuFKD2ubPgu"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "### Measuring values from `physics.data`\n",
+ "The `physics.data` structure contains all of the dynamic variables and intermediate results produced by the simulation. These are expected to change on each timestep. \n",
+ "\n",
+ "Below we simulate for `duration` seconds and plot the angular velocity of the top and the height of its stem as a function of time."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "1XXB6asJoZ2N"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Measuring values {vertical-output: true}\n",
+ "\n",
+ "timevals = []\n",
+ "angular_velocity = []\n",
+ "stem_height = []\n",
+ "\n",
+ "# Simulate and save data\n",
+ "physics.reset(0)\n",
+ "while physics.data.time \u003c duration:\n",
+ "  physics.step()\n",
+ "  timevals.append(physics.data.time)\n",
+ "  angular_velocity.append(physics.data.qvel[3:6].copy())\n",
+ "  stem_height.append(physics.named.data.geom_xpos['stem', 'z'])\n",
+ "\n",
+ "dpi = 100\n",
+ "width = 480\n",
+ "height = 640\n",
+ "figsize = (width / dpi, height / dpi)\n",
+ "_, ax = plt.subplots(2, 1, figsize=figsize, dpi=dpi, sharex=True)\n",
+ "\n",
+ "ax[0].plot(timevals, angular_velocity)\n",
+ "ax[0].set_title('angular velocity')\n",
+ "ax[0].set_ylabel('radians / second')\n",
+ "\n",
+ "ax[1].plot(timevals, stem_height)\n",
+ "ax[1].set_xlabel('time (seconds)')\n",
+ "ax[1].set_ylabel('meters')\n",
+ "_ = ax[1].set_title('stem height')"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "UAMItwu8e1WR"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# PyMJCF tutorial\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "hPiY8m3MssKM"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This library provides a Python object model for MuJoCo's XML-based\n",
+ "[MJCF](http://www.mujoco.org/book/modeling.html) physics modeling language. The\n",
+ "goal of the library is to allow users to easily interact with and modify MJCF\n",
+ "models in Python, similarly to what the JavaScript DOM does for HTML.\n",
+ "\n",
+ "A key feature of this library is the ability to easily compose multiple separate\n",
+ "MJCF models into a larger one. Disambiguation of duplicated names from different\n",
+ "models, or multiple instances of the same model, is handled automatically.\n",
+ "\n",
+ "One typical use case is when we want robots with a variable number of joints. This is a fundamental change to the kinematics, requiring a new XML descriptor and new binary model to be compiled. \n",
+ "\n",
+ "The following snippets realise this scenario and provide a quick example of this library's use case."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gKny5EJ4uVzu"
+ },
+ "cell_type": "code",
+ "source": [
+ "class Leg(object):\n",
+ "  \"\"\"A 2-DoF leg with position actuators.\"\"\"\n",
+ "  def __init__(self, length, rgba):\n",
+ "    self.model = mjcf.RootElement()\n",
+ "\n",
+ "    # Defaults:\n",
+ "    self.model.default.joint.damping = 2\n",
+ "    self.model.default.joint.type = 'hinge'\n",
+ "    self.model.default.geom.type = 'capsule'\n",
+ "    self.model.default.geom.rgba = rgba\n",
+ "\n",
+ "    # Thigh:\n",
+ "    self.thigh = self.model.worldbody.add('body')\n",
+ "    self.hip = self.thigh.add('joint', axis=[0, 0, 1])\n",
+ "    self.thigh.add('geom', fromto=[0, 0, 0, length, 0, 0], size=[length/4])\n",
+ "\n",
+ "    # Shin:\n",
+ "    self.shin = self.thigh.add('body', pos=[length, 0, 0])\n",
+ "    self.knee = self.shin.add('joint', axis=[0, 1, 0])\n",
+ "    self.shin.add('geom', fromto=[0, 0, 0, 0, 0, -length], size=[length/5])\n",
+ "\n",
+ "    # Position actuators:\n",
+ "    self.model.actuator.add('position', joint=self.hip, kp=10)\n",
+ "    self.model.actuator.add('position', joint=self.knee, kp=10)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "cFJerI--UtTy"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `Leg` class describes an abstract articulated leg, with two joints and corresponding proportional-derivative actuators. \n",
+ "\n",
+ "Note that:\n",
+ "\n",
+ "- MJCF attributes correspond directly to arguments of the `add()` method.\n",
+ "- When referencing elements, e.g. when specifying the joint to which an actuator is attached, the MJCF element itself is used, rather than the name string."
+ ]
+ },
+ {
+ "metadata": {
+ "cellView": "both",
+ "id": "SESlL_TidKHx"
+ },
+ "cell_type": "code",
+ "source": [
+ "BODY_RADIUS = 0.1\n",
+ "BODY_SIZE = (BODY_RADIUS, BODY_RADIUS, BODY_RADIUS / 2)\n",
+ "random_state = np.random.RandomState(42)\n",
+ "\n",
+ "def make_creature(num_legs):\n",
+ "  \"\"\"Constructs a creature with `num_legs` legs.\"\"\"\n",
+ "  rgba = random_state.uniform([0, 0, 0, 1], [1, 1, 1, 1])\n",
+ "  model = mjcf.RootElement()\n",
+ "  model.compiler.angle = 'radian'  # Use radians.\n",
+ "\n",
+ "  # Make the torso geom.\n",
+ "  model.worldbody.add(\n",
+ "      'geom', name='torso', type='ellipsoid', size=BODY_SIZE, rgba=rgba)\n",
+ "\n",
+ "  # Attach legs to equidistant sites on the circumference.\n",
+ "  for i in range(num_legs):\n",
+ "    theta = 2 * i * np.pi / num_legs\n",
+ "    hip_pos = BODY_RADIUS * np.array([np.cos(theta), np.sin(theta), 0])\n",
+ "    hip_site = model.worldbody.add('site', pos=hip_pos, euler=[0, 0, theta])\n",
+ "    leg = Leg(length=BODY_RADIUS, rgba=rgba)\n",
+ "    hip_site.attach(leg.model)\n",
+ "\n",
+ "  return model"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "elyuJiI3U3kM"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `make_creature` function uses PyMJCF's `attach()` method to procedurally attach legs to the torso. Note that at this stage both the torso and hip attachment sites are children of the `worldbody`, since their parent body has yet to be instantiated. We'll now make an arena with a chequered floor and two lights, and place our creatures in a grid."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "F7_Tx9P9U_VJ"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Six Creatures on a floor.{vertical-output: true}\n",
+ "\n",
+ "arena = mjcf.RootElement()\n",
+ "chequered = arena.asset.add('texture', type='2d', builtin='checker', width=300,\n",
+ "                            height=300, rgb1=[.2, .3, .4], rgb2=[.3, .4, .5])\n",
+ "grid = arena.asset.add('material', name='grid', texture=chequered,\n",
+ "                       texrepeat=[5, 5], reflectance=.2)\n",
+ "arena.worldbody.add('geom', type='plane', size=[2, 2, .1], material=grid)\n",
+ "for x in [-2, 2]:\n",
+ "  arena.worldbody.add('light', pos=[x, -1, 3], dir=[-x, 1, -2])\n",
+ "\n",
+ "# Instantiate 6 creatures with 3 to 8 legs.\n",
+ "creatures = [make_creature(num_legs=num_legs) for num_legs in range(3, 9)]\n",
+ "\n",
+ "# Place them on a grid in the arena.\n",
+ "height = .15\n",
+ "grid = 5 * BODY_RADIUS\n",
+ "xpos, ypos, zpos = np.meshgrid([-grid, 0, grid], [0, grid], [height])\n",
+ "for i, model in enumerate(creatures):\n",
+ "  # Place spawn sites on a grid.\n",
+ "  spawn_pos = (xpos.flat[i], ypos.flat[i], zpos.flat[i])\n",
+ "  spawn_site = arena.worldbody.add('site', pos=spawn_pos, group=3)\n",
+ "  # Attach to the arena at the spawn sites, with a free joint.\n",
+ "  spawn_site.attach(model).add('freejoint')\n",
+ "\n",
+ "# Instantiate the physics and render.\n",
+ "physics = mjcf.Physics.from_mjcf_model(arena)\n",
+ "PIL.Image.fromarray(physics.render())"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "cMfDaD7PfuoI"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Multi-legged creatures, ready to roam! Let's inject some controls and watch them move. We'll generate a sinusoidal open-loop control signal of fixed frequency and random phase, recording both video frames and the horizontal positions of the torso geoms, in order to plot the movement trajectories."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8Gx39DMEUZDt"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Video of the movement{vertical-output: true}\n",
+ "#@test {\"timeout\": 600}\n",
+ "\n",
+ "duration = 10 # (Seconds)\n",
+ "framerate = 30 # (Hz)\n",
+ "video = []\n",
+ "pos_x = []\n",
+ "pos_y = []\n",
+ "torsos = [] # List of torso geom elements.\n",
+ "actuators = [] # List of actuator elements.\n",
+ "for creature in creatures:\n",
+ "  torsos.append(creature.find('geom', 'torso'))\n",
+ "  actuators.extend(creature.find_all('actuator'))\n",
+ "\n",
+ "# Control signal frequency, phase, amplitude.\n",
+ "freq = 5\n",
+ "phase = 2 * np.pi * random_state.rand(len(actuators))\n",
+ "amp = 0.9\n",
+ "\n",
+ "# Simulate, saving video frames and torso locations.\n",
+ "physics.reset()\n",
+ "while physics.data.time \u003c duration:\n",
+ "  # Inject controls and step the physics.\n",
+ "  physics.bind(actuators).ctrl = amp * np.sin(freq * physics.data.time + phase)\n",
+ "  physics.step()\n",
+ "\n",
+ "  # Save torso horizontal positions using bind().\n",
+ "  pos_x.append(physics.bind(torsos).xpos[:, 0].copy())\n",
+ "  pos_y.append(physics.bind(torsos).xpos[:, 1].copy())\n",
+ "\n",
+ "  # Save video frames.\n",
+ "  if len(video) \u003c physics.data.time * framerate:\n",
+ "    pixels = physics.render()\n",
+ "    video.append(pixels.copy())\n",
+ "\n",
+ "display_video(video, framerate)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "u09JfenWYLZu"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Movement trajectories{vertical-output: true}\n",
+ "\n",
+ "creature_colors = physics.bind(torsos).rgba[:, :3]\n",
+ "fig, ax = plt.subplots(figsize=(4, 4))\n",
+ "ax.set_prop_cycle(color=creature_colors)\n",
+ "_ = ax.plot(pos_x, pos_y, linewidth=4)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "kggQyvNpf_Y9"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The plot above shows the movement trajectories of the creatures. Note how `physics.bind(torsos)` was used to access both `xpos` and `rgba` values. Once the `Physics` has been instantiated by `from_mjcf_model()`, the `bind()` method exposes both the associated `mjData` and `mjModel` fields of an `mjcf` element, providing unified access to all quantities in the simulation."
+ ]
+ },
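+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "As a minimal, self-contained sketch of this unified access, the cell below binds a single element. The one-geom model and the names `demo_model`, `ball` and `demo_physics` are illustrative only and are not part of the scene above. Through the same bound object we read an `mjData` field (`xpos`), read an `mjModel` field (`size`) and write an `mjModel` field (`rgba`)."
+ ]
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from dm_control import mjcf\n",
+ "\n",
+ "# Illustrative one-geom model (not part of the creature scene).\n",
+ "demo_model = mjcf.RootElement()\n",
+ "ball = demo_model.worldbody.add(\n",
+ "    'geom', name='ball', type='sphere', size=[0.1], pos=[0, 0, 1])\n",
+ "demo_physics = mjcf.Physics.from_mjcf_model(demo_model)\n",
+ "\n",
+ "bound = demo_physics.bind(ball)\n",
+ "print('xpos (from mjData):', bound.xpos)\n",
+ "print('size (from mjModel):', bound.size)\n",
+ "\n",
+ "# Writing through the binding modifies the underlying mjModel field.\n",
+ "bound.rgba = [0, 0, 1, 1]\n",
+ "print('rgba after write:', bound.rgba)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },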
+ {
+ "metadata": {
+ "id": "wcRX_wu_8q8u"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Composer tutorial"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "1DMhNPE5tSdw"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "In this tutorial we will create a task requiring our \"creature\" above to press a colour-changing button on the floor with a prescribed force. We begin by implementing our creature as a `composer.Entity`:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WwfzIqgNuFKt"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title The `Creature` class\n",
+ "\n",
+ "\n",
+ "class Creature(composer.Entity):\n",
+ "  \"\"\"A multi-legged creature derived from `composer.Entity`.\"\"\"\n",
+ "  def _build(self, num_legs):\n",
+ "    self._model = make_creature(num_legs)\n",
+ "\n",
+ "  def _build_observables(self):\n",
+ "    return CreatureObservables(self)\n",
+ "\n",
+ "  @property\n",
+ "  def mjcf_model(self):\n",
+ "    return self._model\n",
+ "\n",
+ "  @property\n",
+ "  def actuators(self):\n",
+ "    return tuple(self._model.find_all('actuator'))\n",
+ "\n",
+ "\n",
+ "# Add simple observable features for joint angles and velocities.\n",
+ "class CreatureObservables(composer.Observables):\n",
+ "\n",
+ "  @composer.observable\n",
+ "  def joint_positions(self):\n",
+ "    all_joints = self._entity.mjcf_model.find_all('joint')\n",
+ "    return observable.MJCFFeature('qpos', all_joints)\n",
+ "\n",
+ "  @composer.observable\n",
+ "  def joint_velocities(self):\n",
+ "    all_joints = self._entity.mjcf_model.find_all('joint')\n",
+ "    return observable.MJCFFeature('qvel', all_joints)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "CXZOBK6RkjxH"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `Creature` Entity includes generic Observables for joint angles and velocities. Because `find_all()` is called on the `Creature`'s MJCF model, it will only return the creature's leg joints, and not the \"free\" joint with which it will be attached to the world.\n",
+ "\n",
+ "Note that Composer Entities should override the `_build` and `_build_observables` methods rather than `__init__`. The implementation of `__init__` in the base class calls `_build` and `_build_observables`, in that order, to ensure that the entity's MJCF model is created before its observables. This was a design choice which allows the user to refer to an observable as an attribute (`entity.observables.foo`) while still making it clear which attributes are observables.\n",
+ "\n",
+ "The stateful `Button` class derives from `composer.Entity` and implements the `initialize_episode` and `after_substep` callbacks."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BE9VU2EOvR-u"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title The `Button` class\n",
+ "\n",
+ "NUM_SUBSTEPS = 25 # The number of physics substeps per control timestep.\n",
+ "\n",
+ "\n",
+ "class Button(composer.Entity):\n",
+ "  \"\"\"A button Entity which changes colour when pressed with certain force.\"\"\"\n",
+ "  def _build(self, target_force_range=(5, 10)):\n",
+ "    self._min_force, self._max_force = target_force_range\n",
+ "    self._mjcf_model = mjcf.RootElement()\n",
+ "    self._geom = self._mjcf_model.worldbody.add(\n",
+ "        'geom', type='cylinder', size=[0.25, 0.02], rgba=[1, 0, 0, 1])\n",
+ "    self._site = self._mjcf_model.worldbody.add(\n",
+ "        'site', type='cylinder', size=self._geom.size*1.01, rgba=[1, 0, 0, 0])\n",
+ "    self._sensor = self._mjcf_model.sensor.add('touch', site=self._site)\n",
+ "    self._num_activated_steps = 0\n",
+ "\n",
+ "  def _build_observables(self):\n",
+ "    return ButtonObservables(self)\n",
+ "\n",
+ "  @property\n",
+ "  def mjcf_model(self):\n",
+ "    return self._mjcf_model\n",
+ "\n",
+ "  # Update the activation (and colour) if the desired force is applied.\n",
+ "  def _update_activation(self, physics):\n",
+ "    current_force = physics.bind(self.touch_sensor).sensordata[0]\n",
+ "    self._is_activated = (current_force \u003e= self._min_force and\n",
+ "                          current_force \u003c= self._max_force)\n",
+ "    physics.bind(self._geom).rgba = (\n",
+ "        [0, 1, 0, 1] if self._is_activated else [1, 0, 0, 1])\n",
+ "    self._num_activated_steps += int(self._is_activated)\n",
+ "\n",
+ "  def initialize_episode(self, physics, random_state):\n",
+ "    self._reward = 0.0\n",
+ "    self._num_activated_steps = 0\n",
+ "    self._update_activation(physics)\n",
+ "\n",
+ "  def after_substep(self, physics, random_state):\n",
+ "    self._update_activation(physics)\n",
+ "\n",
+ "  @property\n",
+ "  def touch_sensor(self):\n",
+ "    return self._sensor\n",
+ "\n",
+ "  @property\n",
+ "  def num_activated_steps(self):\n",
+ "    return self._num_activated_steps\n",
+ "\n",
+ "\n",
+ "class ButtonObservables(composer.Observables):\n",
+ "  \"\"\"A touch sensor which averages contact force over physics substeps.\"\"\"\n",
+ "  @composer.observable\n",
+ "  def touch_force(self):\n",
+ "    return observable.MJCFFeature('sensordata', self._entity.touch_sensor,\n",
+ "                                  buffer_size=NUM_SUBSTEPS, aggregator='mean')"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "D9vB5nCwkyIW"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Note how the Button counts the number of sub-steps during which it is pressed with the desired force. It also exposes an `Observable` of the force being applied to the button, whose value is an average of the readings over the physics time-steps.\n",
+ "\n",
+ "Next, we use the `variation` module to define a random initialiser for the button position:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "aDTTQMHtVawM"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Random initialiser using `composer.variation`\n",
+ "\n",
+ "\n",
+ "class UniformCircle(variation.Variation):\n",
+ "  \"\"\"A uniformly sampled horizontal point on a circle of radius `distance`.\"\"\"\n",
+ "  def __init__(self, distance):\n",
+ "    self._distance = distance\n",
+ "    self._heading = distributions.Uniform(0, 2*np.pi)\n",
+ "\n",
+ "  def __call__(self, initial_value=None, current_value=None, random_state=None):\n",
+ "    distance, heading = variation.evaluate(\n",
+ "        (self._distance, self._heading), random_state=random_state)\n",
+ "    return (distance*np.cos(heading), distance*np.sin(heading), 0)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "dgZwP-pvxJdt"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title The `PressWithSpecificForce` task\n",
+ "\n",
+ "\n",
+ "class PressWithSpecificForce(composer.Task):\n",
+ "\n",
+ "  def __init__(self, creature):\n",
+ "    self._creature = creature\n",
+ "    self._arena = floors.Floor()\n",
+ "    self._arena.add_free_entity(self._creature)\n",
+ "    self._arena.mjcf_model.worldbody.add('light', pos=(0, 0, 4))\n",
+ "    self._button = Button()\n",
+ "    self._arena.attach(self._button)\n",
+ "\n",
+ "    # Configure initial poses\n",
+ "    self._creature_initial_pose = (0, 0, 0.15)\n",
+ "    button_distance = distributions.Uniform(0.5, .75)\n",
+ "    self._button_initial_pose = UniformCircle(button_distance)\n",
+ "\n",
+ "    # Configure variators\n",
+ "    self._mjcf_variator = variation.MJCFVariator()\n",
+ "    self._physics_variator = variation.PhysicsVariator()\n",
+ "\n",
+ "    # Configure and enable observables\n",
+ "    pos_corruptor = noises.Additive(distributions.Normal(scale=0.01))\n",
+ "    self._creature.observables.joint_positions.corruptor = pos_corruptor\n",
+ "    self._creature.observables.joint_positions.enabled = True\n",
+ "    vel_corruptor = noises.Multiplicative(distributions.LogNormal(sigma=0.01))\n",
+ "    self._creature.observables.joint_velocities.corruptor = vel_corruptor\n",
+ "    self._creature.observables.joint_velocities.enabled = True\n",
+ "    self._button.observables.touch_force.enabled = True\n",
+ "\n",
+ "    def to_button(physics):\n",
+ "      button_pos, _ = self._button.get_pose(physics)\n",
+ "      return self._creature.global_vector_to_local_frame(physics, button_pos)\n",
+ "\n",
+ "    self._task_observables = {}\n",
+ "    self._task_observables['button_position'] = observable.Generic(to_button)\n",
+ "\n",
+ "    for obs in self._task_observables.values():\n",
+ "      obs.enabled = True\n",
+ "\n",
+ "    self.control_timestep = NUM_SUBSTEPS * self.physics_timestep\n",
+ "\n",
+ "  @property\n",
+ "  def root_entity(self):\n",
+ "    return self._arena\n",
+ "\n",
+ "  @property\n",
+ "  def task_observables(self):\n",
+ "    return self._task_observables\n",
+ "\n",
+ "  def initialize_episode_mjcf(self, random_state):\n",
+ "    self._mjcf_variator.apply_variations(random_state)\n",
+ "\n",
+ "  def initialize_episode(self, physics, random_state):\n",
+ "    self._physics_variator.apply_variations(physics, random_state)\n",
+ "    creature_pose, button_pose = variation.evaluate(\n",
+ "        (self._creature_initial_pose, self._button_initial_pose),\n",
+ "        random_state=random_state)\n",
+ "    self._creature.set_pose(physics, position=creature_pose)\n",
+ "    self._button.set_pose(physics, position=button_pose)\n",
+ "\n",
+ "  def get_reward(self, physics):\n",
+ "    return self._button.num_activated_steps / NUM_SUBSTEPS"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "dRuuZdLpthbv"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Instantiating an environment{vertical-output: true}\n",
+ "\n",
+ "creature = Creature(num_legs=4)\n",
+ "task = PressWithSpecificForce(creature)\n",
+ "env = composer.Environment(task, random_state=np.random.RandomState(42))\n",
+ "\n",
+ "env.reset()\n",
+ "PIL.Image.fromarray(env.physics.render())"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "giTL_6euZFlw"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# The *Control Suite*"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "zfIcrDECtdB2"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The **Control Suite** is a set of stable, well-tested tasks designed to serve as a benchmark for continuous control learning agents. Tasks are written using the basic MuJoCo wrapper interface. Standardised action, observation and reward structures make suite-wide benchmarking simple and learning curves easy to interpret. Control Suite domains are not meant to be modified, in order to facilitate benchmarking. For full details regarding benchmarking, please refer to our original [publication](https://arxiv.org/abs/1801.00690).\n",
+ "\n",
+ "A video of solved benchmark tasks is [available here](https://www.youtube.com/watch?v=rAai4QzcYbs\u0026feature=youtu.be).\n",
+ "\n",
+ "The suite comes with convenient module-level tuples for iterating over tasks:"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "a_whTJG8uTp1"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Iterating over tasks{vertical-output: true}\n",
+ "\n",
+ "max_len = max(len(d) for d, _ in suite.BENCHMARKING)\n",
+ "for domain, task in suite.BENCHMARKING:\n",
+ "  print(f'{domain:\u003c{max_len}} {task}')"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "qN8y3etfZFly"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Loading and simulating a `suite` task{vertical-output: true}\n",
+ "\n",
+ "# Load the environment\n",
+ "random_state = np.random.RandomState(42)\n",
+ "env = suite.load('hopper', 'stand', task_kwargs={'random': random_state})\n",
+ "\n",
+ "# Simulate episode with random actions\n",
+ "duration = 4 # Seconds\n",
+ "frames = []\n",
+ "ticks = []\n",
+ "rewards = []\n",
+ "observations = []\n",
+ "\n",
+ "spec = env.action_spec()\n",
+ "time_step = env.reset()\n",
+ "\n",
+ "while env.physics.data.time \u003c duration:\n",
+ "\n",
+ "  action = random_state.uniform(spec.minimum, spec.maximum, spec.shape)\n",
+ "  time_step = env.step(action)\n",
+ "\n",
+ "  camera0 = env.physics.render(camera_id=0, height=200, width=200)\n",
+ "  camera1 = env.physics.render(camera_id=1, height=200, width=200)\n",
+ "  frames.append(np.hstack((camera0, camera1)))\n",
+ "  rewards.append(time_step.reward)\n",
+ "  observations.append(copy.deepcopy(time_step.observation))\n",
+ "  ticks.append(env.physics.data.time)\n",
+ "\n",
+ "html_video = display_video(frames, framerate=1./env.control_timestep())\n",
+ "\n",
+ "# Show video and plot reward and observations\n",
+ "num_sensors = len(time_step.observation)\n",
+ "\n",
+ "_, ax = plt.subplots(1 + num_sensors, 1, sharex=True, figsize=(4, 8))\n",
+ "ax[0].plot(ticks, rewards)\n",
+ "ax[0].set_ylabel('reward')\n",
+ "ax[-1].set_xlabel('time')\n",
+ "\n",
+ "for i, key in enumerate(time_step.observation):\n",
+ "  data = np.asarray([observations[j][key] for j in range(len(observations))])\n",
+ "  ax[i+1].plot(ticks, data, label=key)\n",
+ "  ax[i+1].set_ylabel(key)\n",
+ "\n",
+ "html_video"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "ggVbQr5hZFl5"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Visualizing an initial state of one task per domain in the Control Suite\n",
+ "#@test {\"timeout\": 180}\n",
+ "domains_tasks = {domain: task for domain, task in suite.ALL_TASKS}\n",
+ "random_state = np.random.RandomState(42)\n",
+ "num_domains = len(domains_tasks)\n",
+ "n_col = num_domains // int(np.sqrt(num_domains))\n",
+ "n_row = num_domains // n_col + int(0 \u003c num_domains % n_col)\n",
+ "_, ax = plt.subplots(n_row, n_col, figsize=(12, 12))\n",
+ "for a in ax.flat:\n",
+ "  a.axis('off')\n",
+ "  a.grid(False)\n",
+ "\n",
+ "print(f'Iterating over all {num_domains} domains in the Suite:')\n",
+ "for j, [domain, task] in enumerate(domains_tasks.items()):\n",
+ "  print(domain, task)\n",
+ "\n",
+ "  env = suite.load(domain, task, task_kwargs={'random': random_state})\n",
+ "  timestep = env.reset()\n",
+ "  pixels = env.physics.render(height=200, width=200, camera_id=0)\n",
+ "\n",
+ "  ax.flat[j].imshow(pixels)\n",
+ "  ax.flat[j].set_title(domain + ': ' + task)\n",
+ "\n",
+ "clear_output()"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "JHSvxHiaopDb"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Locomotion"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "yTn3C03dpHzL"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Humanoid running along corridor with obstacles\n",
+ "\n",
+ "As an illustrative example of using the Locomotion infrastructure to build an RL environment, consider placing a humanoid in a corridor with walls, and a task specifying that the humanoid will be rewarded for running along this corridor, navigating around the wall obstacles using vision. We instantiate the environment as a composition of the Walker, Arena, and Task as follows. First, we build a position-controlled CMU humanoid walker. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gE8rrB7PpN9X"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title A position controlled `cmu_humanoid`\n",
+ "\n",
+ "walker = cmu_humanoid.CMUHumanoidPositionControlledV2020(\n",
+ "    observable_options={'egocentric_camera': dict(enabled=True)})"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "3fYbaDflBrgE"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Next, we construct a corridor-shaped arena that is obstructed by walls."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "t-O17Fnm3E6R"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title A corridor arena with wall obstacles\n",
+ "\n",
+ "arena = corridor_arenas.WallsCorridor(\n",
+ "    wall_gap=3.,\n",
+ "    wall_width=distributions.Uniform(2., 3.),\n",
+ "    wall_height=distributions.Uniform(2.5, 3.5),\n",
+ "    corridor_width=4.,\n",
+ "    corridor_length=30.,\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "970nN38eBx-R"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The task constructor places the walker in the arena."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dz4Jy2UGpQ4Z"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title A task to navigate the arena\n",
+ "\n",
+ "task = corridor_tasks.RunThroughCorridor(\n",
+ "    walker=walker,\n",
+ "    arena=arena,\n",
+ "    walker_spawn_position=(0.5, 0, 0),\n",
+ "    target_velocity=3.0,\n",
+ "    physics_timestep=0.005,\n",
+ "    control_timestep=0.03,\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "r-Oy-qTSB4HW"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Finally, a task that rewards the agent for running down the corridor at a specific velocity is instantiated as a `composer.Environment`."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "sQXlaEZk3ytl"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title The `RunThroughCorridor` environment\n",
+ "\n",
+ "env = composer.Environment(\n",
+ "    task=task,\n",
+ "    time_limit=10,\n",
+ "    random_state=np.random.RandomState(42),\n",
+ "    strip_singleton_obs_buffer_dim=True,\n",
+ ")\n",
+ "env.reset()\n",
+ "pixels = []\n",
+ "for camera_id in range(3):\n",
+ "  pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
+ "PIL.Image.fromarray(np.hstack(pixels))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "HuuQLm8YopDe"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Multi-Agent Soccer"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "OPNshDEEopDf"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Building on Composer and Locomotion libraries, the Multi-agent soccer environments, introduced in [this paper](https://arxiv.org/abs/1902.07151), follow a consistent task structure of Walkers, Arena, and Task where instead of a single walker, we inject multiple walkers that can interact with each other physically in the same scene. The code snippet below shows how to instantiate a 2-vs-2 Multi-agent Soccer environment with the simple, 5 degree-of-freedom `BoxHead` walker type."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "zAb3je0DAeQo"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title 2-v-2 `Boxhead` soccer\n",
+ "\n",
+ "random_state = np.random.RandomState(42)\n",
+ "env = soccer.load(\n",
+ "    team_size=2,\n",
+ "    time_limit=45.,\n",
+ "    random_state=random_state,\n",
+ "    disable_walker_contacts=False,\n",
+ "    walker_type=soccer.WalkerType.BOXHEAD,\n",
+ ")\n",
+ "env.reset()\n",
+ "pixels = []\n",
+ "# Select a random subset of 6 cameras (soccer envs have lots of cameras)\n",
+ "cameras = random_state.choice(env.physics.model.ncam, 6, replace=False)\n",
+ "for camera_id in cameras:\n",
+ "  pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
+ "image = np.vstack((np.hstack(pixels[:3]), np.hstack(pixels[3:])))\n",
+ "PIL.Image.fromarray(image)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "J_5C2k0NGvxE"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `BoxHead` walker can easily be replaced by another walker type, e.g. `WalkerType.ANT` (shown here in a 3-vs-3 game):"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WDIGodhBG-Mn"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title 3-v-3 `Ant` soccer\n",
+ "\n",
+ "random_state = np.random.RandomState(42)\n",
+ "env = soccer.load(\n",
+ " team_size=3,\n",
+ " time_limit=45.,\n",
+ " random_state=random_state,\n",
+ " disable_walker_contacts=False,\n",
+ " walker_type=soccer.WalkerType.ANT,\n",
+ ")\n",
+ "env.reset()\n",
+ "\n",
+ "pixels = []\n",
+ "cameras = random_state.choice(env.physics.model.ncam, 6, replace=False)\n",
+ "for camera_id in cameras:\n",
+ " pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
+ "image = np.vstack((np.hstack(pixels[:3]), np.hstack(pixels[3:])))\n",
+ "PIL.Image.fromarray(image)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "MvK9BW4A5c9p"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Manipulation"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "jPt27n2Dch_m"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The `manipulation` module provides a robotic arm, a set of simple objects, and tools for building reward functions for manipulation tasks."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "cZxmJoovahCA"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Listing all `manipulation` tasks{vertical-output: true}\n",
+ "\n",
+ "# `ALL` is a tuple containing the names of all of the `manipulation` environments.\n",
+ "print('\\n'.join(manipulation.ALL))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "oj0cJFlR5nTS"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Listing `manipulation` tasks that use vision{vertical-output: true}\n",
+ "print('\\n'.join(manipulation.get_environments_by_tag('vision')))"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "id": "e_6q4FqFIKxy"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Loading and simulating a `manipulation` task{vertical-output: true}\n",
+ "\n",
+ "env = manipulation.load('stack_2_of_3_bricks_random_order_vision', seed=42)\n",
+ "action_spec = env.action_spec()\n",
+ "\n",
+ "def sample_random_action():\n",
+ " return env.random_state.uniform(\n",
+ " low=action_spec.minimum,\n",
+ " high=action_spec.maximum,\n",
+ " ).astype(action_spec.dtype, copy=False)\n",
+ "\n",
+ "# Step the environment through a full episode using random actions and record\n",
+ "# the camera observations.\n",
+ "frames = []\n",
+ "timestep = env.reset()\n",
+ "frames.append(timestep.observation['front_close'])\n",
+ "while not timestep.last():\n",
+ " timestep = env.step(sample_random_action())\n",
+ " frames.append(timestep.observation['front_close'])\n",
+ "all_frames = np.concatenate(frames, axis=0)\n",
+ "display_video(all_frames, 30)"
+ ],
+ "outputs": [],
+ "execution_count": null
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "YkBQUjm6gbGF",
+ "YvyGCsgSCxHQ",
+ "wtDN43hIJh2C",
+ "jZXz9rPYGA-Y",
+ "MdUF2UYmR4TA",
+ "nQ8XOnRQx7T1",
+ "gf9h_wi9weet",
+ "NCcZxrDDB1Cj",
+ "t5hY0fyXFLcf",
+ "Z0UodCxS_v49",
+ "22ENjtVuhwsm",
+ "SHppAOjvSupc",
+ "rRuFKD2ubPgu",
+ "UAMItwu8e1WR",
+ "wcRX_wu_8q8u",
+ "giTL_6euZFlw",
+ "JHSvxHiaopDb",
+ "yTn3C03dpHzL",
+ "HuuQLm8YopDe",
+ "MvK9BW4A5c9p"
+ ],
+ "last_runtime": {
+ "build_target": "",
+ "kind": "local"
+ },
+ "name": "dm_control",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}