diff --git a/demonstrations/demo_store.py b/demonstrations/demo_store.py index 1509004..b9cc0b3 100644 --- a/demonstrations/demo_store.py +++ b/demonstrations/demo_store.py @@ -1,24 +1,23 @@ """Script for uploading the collected demos.""" import logging import os +import tempfile import warnings import zipfile +from copy import deepcopy +from pathlib import Path from typing import Optional import numpy as np -import tempfile -from pathlib import Path -from copy import deepcopy - import wget from tqdm import tqdm from bigym.bigym_env import CONTROL_FREQUENCY_MAX from bigym.const import CACHE_PATH, DEMO_RELEASES, DEMO_VERSION from demonstrations.const import SAFETENSORS_SUFFIX -from demonstrations.utils import Metadata, ObservationMode from demonstrations.demo import Demo, LightweightDemo from demonstrations.demo_converter import DemoConverter +from demonstrations.utils import Metadata, ObservationMode class DemoNotFoundError(Exception): @@ -165,9 +164,10 @@ def _get_demos(self, demos_dir: Path, amount: int) -> list[Demo]: files = list(demos_dir.glob(f"*{SAFETENSORS_SUFFIX}")) if amount > len(files): raise TooManyDemosRequestedError(amount, len(files)) - elif amount > 0: - files = files[:amount] + np.random.shuffle(files) + if amount > 0: + files = files[:amount] return [Demo.from_safetensors(file) for file in files] def _get_demos_count(self, demos_dir: Path) -> int: