import io
import operator
from concurrent.futures import ThreadPoolExecutor, wait
from itertools import chain
from queue import Queue
from typing import Any, Optional
from uuid import uuid4

import gymnasium as gym
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import simplejpeg
from PIL import Image


class StorageWrapper(gym.Wrapper):
    # Sentinel pushed onto the writer queue to signal end-of-stream.
    QueueSentinel = None

    def __init__(
        self,
        env: gym.Env,
        base_dir: str,
        instruction: str,
        batch_size: int = 32,
        schema: Optional[pa.Schema] = None,
        always_record: bool = False,
        basename_template: Optional[str] = None,
        max_rows_per_group: Optional[int] = None,
        max_rows_per_file: Optional[int] = None,
    ):
        """
        Asynchronously log environment transitions to a Parquet dataset
        on disk.

        Observation handling:

        - Expects observations to be dictionaries.
        - RGB camera frames are JPEG-encoded; depth frames are TIFF-encoded.
        - Numpy arrays with ndim > 1 inside the observation dict are flattened
          in-place, and their original shapes are stored alongside as
          ``"<key>_shape"``. Nested dicts are traversed recursively.
        - Lists/tuples of arrays are not supported.
        - ``close()`` must be called to flush the final batch.

        Parameters
        ----------
        env : gym.Env
            The environment to wrap.
        base_dir : str
            Output directory where the Parquet dataset will be written.
        instruction : str
            Natural-language task instruction stored with every transition.
        batch_size : int
            Number of transitions to accumulate before flushing a RecordBatch
            to the writer queue.
        schema : Optional[pa.Schema], default=None
            Optional Arrow schema to enforce for all batches. If None, the
            schema is inferred from the first flushed batch and then reused.
        always_record : bool, default=False
            If True, record from the very first step and resume automatically
            after every ``reset()``; otherwise recording stays paused until
            ``start_record()`` is called.
        basename_template : Optional[str], default=None
            Template controlling Parquet file basenames. Passed through to
            ``pyarrow.dataset.write_dataset``.
        max_rows_per_group : Optional[int], default=None
            Maximum row count per Parquet row group. Passed through to
            ``pyarrow.dataset.write_dataset``.
        max_rows_per_file : Optional[int], default=None
            Maximum row count per Parquet file. Passed through to
            ``pyarrow.dataset.write_dataset``.
        """
        super().__init__(env)
        self.base_dir = base_dir
        self.batch_size = batch_size
        self.schema = schema
        self.basename_template = basename_template
        self.max_rows_per_group = max_rows_per_group
        self.max_rows_per_file = max_rows_per_file
        self.buffer: list[dict[str, Any]] = []
        self.step_cnt = 0
        self._pause = not always_record
        self.always_record = always_record
        self.instruction = instruction
        self._success = False
        self._prev_action = None
        self.thread_pool = ThreadPoolExecutor()
        # Bounded queue: if the disk writer falls behind, step() blocks on
        # put() instead of buffering batches without limit.
        self.queue: Queue[pa.Table | pa.RecordBatch] = Queue(maxsize=2)
        self.uuid = uuid4()
        # The writer occupies one pool thread for the wrapper's lifetime.
        self._writer_future = self.thread_pool.submit(self._writer_worker)

    def _generator_from_queue(self):
        # Yield queued batches until close() pushes the sentinel.
        while (batch := self.queue.get()) is not self.QueueSentinel:
            yield batch

    def _writer_worker(self):
        gen = self._generator_from_queue()
        # Block until the first batch arrives; _flush() has set self.schema
        # by then, so write_dataset always receives a concrete schema.
        first = next(gen)
        ds.write_dataset(
            data=chain([first], gen),
            base_dir=self.base_dir,
            format="parquet",
            schema=self.schema,
            existing_data_behavior="overwrite_or_ignore",
            basename_template=self.basename_template,
            max_rows_per_group=self.max_rows_per_group,
            max_rows_per_file=self.max_rows_per_file,
            # Partition on the episode uuid so each episode is written to
            # files named after it ("filename" flavor).
            partitioning=ds.partitioning(
                schema=pa.schema(fields=[pa.field("uuid", pa.string())]),
                flavor="filename",
            ),
        )

    def _flush(self):
        batch = pa.RecordBatch.from_pylist(self.buffer, schema=self.schema)
        if self.schema is None:
            # First flush: keep the inferred schema for all later batches.
            self.schema = batch.schema
        self.queue.put(batch)
        self.buffer.clear()

    def _flatten_arrays(self, d: dict[str, Any]):
        # NOTE: lists / tuples of arrays are not supported.
        updates = {}
        for k, v in d.items():
            if isinstance(v, dict):
                self._flatten_arrays(v)
            elif isinstance(v, np.ndarray) and v.ndim > 1:
                # flatten() returns a copy, so v.shape is still the original
                # shape, e.g. an "img" of shape (64, 64, 3) is stored as a
                # flat array plus an "img_shape" entry of (64, 64, 3).
                d[k] = v.flatten()
                updates[f"{k}_shape"] = v.shape
        d.update(updates)

    def _encode_images(self, obs: dict[str, Any]):
        # JPEG-encode every camera's RGB frame in parallel, replacing the
        # raw array in-place with the encoded bytes.
        list(
            self.thread_pool.map(
                lambda cam: operator.setitem(
                    obs["frames"][cam]["rgb"],
                    "data",
                    simplejpeg.encode_jpeg(
                        np.ascontiguousarray(obs["frames"][cam]["rgb"]["data"])
                    ),
                ),
                obs["frames"],
            )
        )

        # Encode depth frames as single-channel TIFFs (lossless, unlike JPEG).
        def to_tiff(depth_data):
            img_bytes = io.BytesIO()
            Image.fromarray(
                depth_data.reshape((depth_data.shape[0], depth_data.shape[1])),
            ).save(img_bytes, format="TIFF")  # type: ignore
            return img_bytes.getvalue()  # type: ignore

        list(
            self.thread_pool.map(
                lambda cam: (
                    operator.setitem(
                        obs["frames"][cam]["depth"],
                        "data",
                        to_tiff(obs["frames"][cam]["depth"]["data"]),
                    )
                    if "depth" in obs["frames"][cam]
                    else None
                ),
                obs["frames"],
            )
        )

    def step(self, action):
        # NOTE: expects the observation to be a dictionary.
        # Re-raise failures from the writer thread here, in the stepping
        # thread, rather than queueing batches nobody will ever drain.
        if self._writer_future.done():
            exc = self._writer_future.exception()
            assert exc is not None
            msg = "Writer thread failed"
            raise RuntimeError(msg) from exc
        obs, reward, terminated, truncated, info = self.env.step(action)
        if not self._pause:
            assert isinstance(obs, dict)
            if "frames" in obs:
                self._encode_images(obs)
            self._flatten_arrays(obs)
            if info.get("success"):
                self.success()
            # Each row pairs the current observation with the action that
            # produced it (i.e. the action from the previous step).
            self.buffer.append(
                {
                    "obs": obs,
                    "reward": reward,
                    "step": self.step_cnt,
                    "uuid": self.uuid.hex,
                    "success": self._success,
                    "action": self._prev_action,
                    "instruction": self.instruction,
                }
            )
            self._prev_action = action
            self.step_cnt += 1
            if len(self.buffer) == self.batch_size:
                self._flush()
        return obs, reward, terminated, truncated, info

    def success(self):
        # Mark the episode successful; the flag is stored with every
        # subsequent row of this episode.
        self._success = True

    def stop_record(self):
        self._pause = True
        if len(self.buffer) > 0:
            self._flush()

    def start_record(self):
        self._pause = False

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        # Flush whatever is left over from the episode that just ended.
        if len(self.buffer) > 0:
            self._flush()
        self._pause = not self.always_record
        self._success = False
        self._prev_action = None
        # Forward seed/options so seeding actually reaches the wrapped env.
        obs, info = self.env.reset(seed=seed, options=options)
        self.step_cnt = 0
        # Fresh uuid per episode: each episode gets its own partition.
        self.uuid = uuid4()
        return obs, info

    def close(self):
        if len(self.buffer) > 0:
            self._flush()
        # Signal end-of-stream, then wait for the writer to finish so the
        # dataset on disk is complete before the interpreter exits.
        self.queue.put(self.QueueSentinel)
        wait([self._writer_future])
        self.thread_pool.shutdown()
        super().close()
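

# A minimal usage sketch, not part of the wrapper itself. It assumes a toy
# dict-observation environment defined inline (StorageWrapper requires dict
# observations); the env, the "./demo_dataset" path, and the instruction
# string are all illustrative stand-ins, not part of the original module.
if __name__ == "__main__":

    class DummyDictEnv(gym.Env):
        """Tiny stand-in env with dict observations, for the sketch only."""

        observation_space = gym.spaces.Dict(
            {"state": gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32)}
        )
        action_space = gym.spaces.Discrete(2)

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            return {"state": self.observation_space["state"].sample()}, {}

        def step(self, action):
            obs = {"state": self.observation_space["state"].sample()}
            return obs, 0.0, False, False, {}

    env = StorageWrapper(
        DummyDictEnv(),
        base_dir="./demo_dataset",  # illustrative output path
        instruction="reach the target",  # illustrative instruction
        batch_size=32,
        always_record=True,  # record from the first step of every episode
    )
    for _ in range(2):  # two short episodes, each under its own uuid
        env.reset()
        for _ in range(100):
            env.step(env.action_space.sample())
    env.close()  # flushes the final partial batch and finalizes the files

    # Read the dataset back. Declaring the same filename-flavored
    # partitioning recovers the "uuid" column from the file names.
    dataset = ds.dataset(
        "./demo_dataset",
        format="parquet",
        partitioning=ds.partitioning(
            schema=pa.schema([pa.field("uuid", pa.string())]),
            flavor="filename",
        ),
    )
    print(dataset.to_table().num_rows)  # 200 rows: 2 episodes x 100 steps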