diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 47b7b2eb17..43707bc87b 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -40,7 +40,7 @@ jobs: - name: setup python uses: actions/setup-python@v2 with: - python-version: '3.7' + python-version: '3.8' architecture: x64 - name: install deps diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 152071b5e6..46fc96103b 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -24,14 +24,14 @@ on: - reopened - edited jobs: - validate-naming-convention: - name: Pull Request's title matches naming convention - runs-on: ubuntu-latest - steps: - - uses: deepakputhraya/action-pr-title@master - with: - regex: '^\[(?:feat|fix|doc|refactor|deprecation)\]\s[A-Z].*(? Tuple[bool, str]: import boto3 except ImportError: raise RuntimeError( - "This command requires 'boto3' to be installed. " 'Please install it with command: \n pip install boto3' + "This command requires 'boto3' to be installed. Please install it with command: \n pip install boto3" ) try: diff --git a/aim/cli/server/commands.py b/aim/cli/server/commands.py index b84c1ae68b..7c1587021d 100644 --- a/aim/cli/server/commands.py +++ b/aim/cli/server/commands.py @@ -95,5 +95,5 @@ def server(host, port, repo, ssl_keyfile, ssl_certfile, base_path, log_level, de ) exec_cmd(cmd, stream_output=True) except ShellCommandException: - click.echo('Failed to run Aim Tracking Server. ' 'Please see the logs above for details.') + click.echo('Failed to run Aim Tracking Server. Please see the logs above for details.') exit(1) diff --git a/aim/cli/storage/commands.py b/aim/cli/storage/commands.py index 60ec266142..3210f7a692 100644 --- a/aim/cli/storage/commands.py +++ b/aim/cli/storage/commands.py @@ -56,7 +56,7 @@ def to_3_11(ctx, hashes, yes): try: run = Run(run_hash, repo=repo) if run.check_metrics_version(): - backup_run(run) + backup_run(repo, run.hash) run.update_metrics() index_manager.index(run_hash) else: diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index 2044c75e46..6ad5e6e733 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -11,7 +11,9 @@ get_repo_instance, set_log_level, ) +from aim.sdk.index_manager import RepoIndexManager from aim.sdk.repo import Repo +from aim.sdk.run_status_manager import RunStatusManager from aim.sdk.utils import clean_repo_path from aim.web.configs import ( AIM_ENV_MODE_KEY, @@ -29,7 +31,7 @@ @click.command() @click.option('-h', '--host', default=AIM_UI_DEFAULT_HOST, type=str) @click.option('-p', '--port', default=AIM_UI_DEFAULT_PORT, type=int) -@click.option('-w', '--workers', default=1, type=int) +@click.option('-w', '--workers', default=2, type=int) @click.option('--uds', required=False, type=click.Path(exists=False, file_okay=True, dir_okay=False, readable=True)) @click.option('--repo', required=False, type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True)) @click.option('--tf_logs', type=click.Path(exists=True, readable=True)) @@ -96,7 +98,7 @@ def up( db_cmd = build_db_upgrade_command() exec_cmd(db_cmd, stream_output=True) except ShellCommandException: - click.echo('Failed to initialize Aim DB. ' 'Please see the logs above for details.') + click.echo('Failed to initialize Aim DB. 
Please see the logs above for details.') return if port == 0: @@ -122,6 +124,11 @@ def up( if profiler: os.environ[AIM_PROFILER_KEY] = '1' + index_mng = RepoIndexManager.get_index_manager(repo_inst) + index_mng.start() + + run_status_mng = RunStatusManager(repo_inst) + run_status_mng.start() try: server_cmd = build_uvicorn_command( 'aim.web.run:app', diff --git a/aim/distributed_hugging_face.py b/aim/distributed_hugging_face.py new file mode 100644 index 0000000000..cbb9f8eec8 --- /dev/null +++ b/aim/distributed_hugging_face.py @@ -0,0 +1,2 @@ +# Alias to SDK distributed hugging face interface +from aim.sdk.adapters.distributed_hugging_face import AimCallback # noqa: F401 diff --git a/aim/ext/notifier/notifier.py b/aim/ext/notifier/notifier.py index 1b237af620..ca73fe59ab 100644 --- a/aim/ext/notifier/notifier.py +++ b/aim/ext/notifier/notifier.py @@ -34,7 +34,7 @@ def notify(self, message: Optional[str] = None, **kwargs): except Exception as e: attempt += 1 if attempt == self.MAX_RETRIES: - logger.error(f'Notifier {sub} failed to send message "{message}". ' f'No retries left.') + logger.error(f'Notifier {sub} failed to send message "{message}". No retries left.') raise NotificationSendError(e) else: logger.error( diff --git a/aim/ext/sshfs/utils.py b/aim/ext/sshfs/utils.py index 3dba53168d..170f9d4fed 100644 --- a/aim/ext/sshfs/utils.py +++ b/aim/ext/sshfs/utils.py @@ -197,7 +197,7 @@ def unmount_remote_repo(mount_point: str, mount_root: str): if exit_code != 0: # in case of failure log warning so the user can unmount manually if needed logger.warning( - f'Could not unmount path: {mount_point}.\n' f'Please unmount manually using command:\n' f'{" ".join(cmd)}' + f'Could not unmount path: {mount_point}.\nPlease unmount manually using command:\n{" ".join(cmd)}' ) else: shutil.rmtree(mount_root) diff --git a/aim/ext/tensorboard_tracker/tracker.py b/aim/ext/tensorboard_tracker/tracker.py index 59af672f85..f902e08069 100644 --- a/aim/ext/tensorboard_tracker/tracker.py +++ b/aim/ext/tensorboard_tracker/tracker.py @@ -31,7 +31,7 @@ def _decode_histogram(value): # This is a bit weird but it seems the histogram counts is usually padded by 0 as tensorboard # only stores the right limits? 
- # See https://github.com/pytorch/pytorch/blob/7d2a18da0b3427fcbe44b461a0aa508194535885/torch/utils/tensorboard/summary.py#L390 # noqa + # See https://github.com/pytorch/pytorch/blob/7d2a18da0b3427fcbe44b461a0aa508194535885/torch/utils/tensorboard/summary.py#L390 bin_counts = bin_counts[1:] bin_range = (bucket_limits[0], bucket_limits[-1]) diff --git a/aim/ext/transport/handlers.py b/aim/ext/transport/handlers.py index 7915bc1055..23c9985d2b 100644 --- a/aim/ext/transport/handlers.py +++ b/aim/ext/transport/handlers.py @@ -51,14 +51,12 @@ def get_tree(**kwargs): name = kwargs['name'] sub = kwargs['sub'] read_only = kwargs['read_only'] - from_union = kwargs['from_union'] index = kwargs['index'] timeout = kwargs['timeout'] - no_cache = kwargs.get('no_cache', False) if index: return ResourceRef(repo._get_index_tree(name, timeout)) else: - return ResourceRef(repo.request_tree(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache)) + return ResourceRef(repo.request_tree(name, sub, read_only=read_only)) def get_structured_run(hash_, read_only, created_at, **kwargs): diff --git a/aim/ext/transport/heartbeat.py b/aim/ext/transport/heartbeat.py index e8009576b1..d6390d63eb 100644 --- a/aim/ext/transport/heartbeat.py +++ b/aim/ext/transport/heartbeat.py @@ -15,10 +15,8 @@ class HeartbeatSender(object): HEARTBEAT_INTERVAL_DEFAULT = 10 NETWORK_CHECK_INTERVAL = 180 - NETWORK_UNSTABLE_WARNING_TEMPLATE = ( - 'Network connection between client `{}` ' 'and server `{}` appears to be unstable.' - ) - NETWORK_ABSENT_WARNING_TEMPLATE = 'Network connection between client `{}` ' 'and server `{}` appears to be absent.' + NETWORK_UNSTABLE_WARNING_TEMPLATE = 'Network connection between client `{}` and server `{}` appears to be unstable.' + NETWORK_ABSENT_WARNING_TEMPLATE = 'Network connection between client `{}` and server `{}` appears to be absent.' 
def __init__( self, @@ -118,7 +116,7 @@ def reset_responses(): class HeartbeatWatcher: - CLIENT_KEEP_ALIVE_TIME_DEFAULT = 30 * 60 # 30 minutes + CLIENT_KEEP_ALIVE_TIME_DEFAULT = 5 * 60 # 5 minutes def __init__(self, heartbeat_pool, keep_alive_time: Union[int, float] = CLIENT_KEEP_ALIVE_TIME_DEFAULT): self._heartbeat_pool = heartbeat_pool diff --git a/aim/ext/transport/message_utils.py b/aim/ext/transport/message_utils.py index 127b1f7e58..ceb52fac27 100644 --- a/aim/ext/transport/message_utils.py +++ b/aim/ext/transport/message_utils.py @@ -5,7 +5,8 @@ from typing import Iterator, Tuple from aim.storage.object import CustomObject -from aim.storage.treeutils import decode_tree, encode_tree # noqa +from aim.storage.treeutils import decode_tree as decode_tree +from aim.storage.treeutils import encode_tree as encode_tree from aim.storage.types import BLOB @@ -45,28 +46,23 @@ def pack_stream(tree: Iterator[Tuple[bytes, bytes]]) -> bytes: yield struct.pack('I', len(key)) + key + struct.pack('?', True) + struct.pack('I', len(val)) + val -def unpack_helper(msg: bytes) -> Tuple[bytes, bytes]: - (key_size,), tail = struct.unpack('I', msg[:4]), msg[4:] - key, tail = tail[:key_size], tail[key_size:] - (is_blob,), tail = struct.unpack('?', tail[:1]), tail[1:] - (value_size,), tail = struct.unpack('I', tail[:4]), tail[4:] - value, tail = tail[:value_size], tail[value_size:] - assert len(tail) == 0 - if is_blob: - yield key, BLOB(data=value) - else: - yield key, value - - def unpack_stream(stream) -> Tuple[bytes, bytes]: for msg in stream: - yield from unpack_helper(msg) + yield from unpack_args(msg) def raise_exception(server_exception): + from filelock import Timeout + module = importlib.import_module(server_exception.get('module_name')) exception = getattr(module, server_exception.get('class_name')) args = json.loads(server_exception.get('args') or []) + message = server_exception.get('message') + + # special handling for lock timeouts as they require lock argument which can't be passed over the network + if exception == Timeout: + raise Exception(message) + raise exception(*args) if args else exception() @@ -75,6 +71,7 @@ def build_exception(exception: Exception): 'module_name': exception.__class__.__module__, 'class_name': exception.__class__.__name__, 'args': json.dumps(exception.args), + 'message': str(exception), } diff --git a/aim/ext/transport/utils.py b/aim/ext/transport/utils.py index 037ad1262e..b556692fe3 100644 --- a/aim/ext/transport/utils.py +++ b/aim/ext/transport/utils.py @@ -13,7 +13,7 @@ def inner(func): def wrapper(*args, **kwargs): try: return func(*args, **kwargs) - except exc_type as e: # noqa + except exc_type: if error_message is not None: logger.error(error_message) raise RuntimeError(error_message) diff --git a/aim/fastai.py b/aim/fastai.py index 2758828947..ab00bee140 100644 --- a/aim/fastai.py +++ b/aim/fastai.py @@ -1,2 +1,2 @@ # Alias to SDK fast.ai interface -from aim.sdk.adapters.fastai import AimCallback # noqa F401 +from aim.sdk.adapters.fastai import AimCallback as AimCallback diff --git a/aim/hf_dataset.py b/aim/hf_dataset.py index e629b15c93..00ecc9cc99 100644 --- a/aim/hf_dataset.py +++ b/aim/hf_dataset.py @@ -1,2 +1,2 @@ # Alias to SDK Hugging Face Datasets interface -from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset # noqa F401 +from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset as HFDataset diff --git a/aim/hugging_face.py b/aim/hugging_face.py index 9fbde32ec8..692ec24865 100644 --- a/aim/hugging_face.py +++ 
b/aim/hugging_face.py @@ -1,2 +1,2 @@ # Alias to SDK Hugging Face interface -from aim.sdk.adapters.hugging_face import AimCallback # noqa F401 +from aim.sdk.adapters.hugging_face import AimCallback as AimCallback diff --git a/aim/keras.py b/aim/keras.py index 3383dff655..e1c6ed28f1 100644 --- a/aim/keras.py +++ b/aim/keras.py @@ -1,2 +1,3 @@ # Alias to SDK Keras interface -from aim.sdk.adapters.keras import AimCallback, AimTracker # noqa F401 +from aim.sdk.adapters.keras import AimCallback as AimCallback +from aim.sdk.adapters.keras import AimTracker as AimTracker diff --git a/aim/keras_tuner.py b/aim/keras_tuner.py index 5f6577cae7..5d264e64d2 100644 --- a/aim/keras_tuner.py +++ b/aim/keras_tuner.py @@ -1,2 +1,2 @@ # Alias to SDK Keras-Tuner interface -from aim.sdk.adapters.keras_tuner import AimCallback # noqa F401 +from aim.sdk.adapters.keras_tuner import AimCallback as AimCallback diff --git a/aim/mxnet.py b/aim/mxnet.py index 403d33d40b..ceacfb118a 100644 --- a/aim/mxnet.py +++ b/aim/mxnet.py @@ -1,2 +1,2 @@ # Alias to SDK mxnet interface -from aim.sdk.adapters.mxnet import AimLoggingHandler # noqa F401 +from aim.sdk.adapters.mxnet import AimLoggingHandler as AimLoggingHandler diff --git a/aim/optuna.py b/aim/optuna.py index 5069d24695..28d0b1dbfe 100644 --- a/aim/optuna.py +++ b/aim/optuna.py @@ -1,2 +1,2 @@ # Alias to SDK Optuna interface -from aim.sdk.adapters.optuna import AimCallback # noqa F401 +from aim.sdk.adapters.optuna import AimCallback as AimCallback diff --git a/aim/paddle.py b/aim/paddle.py index 0c49486419..9069d936ac 100644 --- a/aim/paddle.py +++ b/aim/paddle.py @@ -1,2 +1,2 @@ # Alias to SDK PaddlePaddle interface -from aim.sdk.adapters.paddle import AimCallback # noqa F401 +from aim.sdk.adapters.paddle import AimCallback as AimCallback diff --git a/aim/prophet.py b/aim/prophet.py index 1a43316f44..661e95cd44 100644 --- a/aim/prophet.py +++ b/aim/prophet.py @@ -1,2 +1,2 @@ # Alias to SDK Prophet interface -from aim.sdk.adapters.prophet import AimLogger # noqa F401 +from aim.sdk.adapters.prophet import AimLogger as AimLogger diff --git a/aim/sdk/legacy/__init__.py b/aim/py.typed similarity index 100% rename from aim/sdk/legacy/__init__.py rename to aim/py.typed diff --git a/aim/pytorch.py b/aim/pytorch.py index c493b7a84d..677a68f88c 100644 --- a/aim/pytorch.py +++ b/aim/pytorch.py @@ -1,2 +1,3 @@ # Alias to SDK PyTorch utils -from aim.sdk.adapters.pytorch import track_params_dists, track_gradients_dists # noqa +from aim.sdk.adapters.pytorch import track_gradients_dists as track_gradients_dists +from aim.sdk.adapters.pytorch import track_params_dists as track_params_dists diff --git a/aim/pytorch_ignite.py b/aim/pytorch_ignite.py index 08cd67ce77..2189c6ddf2 100644 --- a/aim/pytorch_ignite.py +++ b/aim/pytorch_ignite.py @@ -1,2 +1,2 @@ # Alias to SDK PyTorch Ignite interface -from aim.sdk.adapters.pytorch_ignite import AimLogger # noqa F401 +from aim.sdk.adapters.pytorch_ignite import AimLogger as AimLogger diff --git a/aim/pytorch_lightning.py b/aim/pytorch_lightning.py index 50d10c1aae..b9a3405f9d 100644 --- a/aim/pytorch_lightning.py +++ b/aim/pytorch_lightning.py @@ -1,2 +1,2 @@ # Alias to SDK PyTorch Lightning interface -from aim.sdk.adapters.pytorch_lightning import AimLogger # noqa F401 +from aim.sdk.adapters.pytorch_lightning import AimLogger as AimLogger diff --git a/aim/sb3.py b/aim/sb3.py index 43fd7899eb..78bdec8ee9 100644 --- a/aim/sb3.py +++ b/aim/sb3.py @@ -1,2 +1,2 @@ # Alias to SDK sb3 interface -from aim.sdk.adapters.sb3 import AimCallback # noqa 
F401 +from aim.sdk.adapters.sb3 import AimCallback as AimCallback diff --git a/aim/sdk/__init__.py b/aim/sdk/__init__.py index 17d6974a67..f7c190da1a 100644 --- a/aim/sdk/__init__.py +++ b/aim/sdk/__init__.py @@ -1,10 +1,3 @@ -# Legacy SDK functions -from aim.sdk.legacy.flush import flush -from aim.sdk.legacy.init import init -from aim.sdk.legacy.select import select_metrics, select_runs -from aim.sdk.legacy.session import Session -from aim.sdk.legacy.track import set_params, track - # pre-defined sequences and custom objects from aim.sdk.objects import Audio, Distribution, Figure, Image, Text from aim.sdk.repo import Repo diff --git a/aim/sdk/adapters/distributed_hugging_face.py b/aim/sdk/adapters/distributed_hugging_face.py new file mode 100644 index 0000000000..561bef82dc --- /dev/null +++ b/aim/sdk/adapters/distributed_hugging_face.py @@ -0,0 +1,458 @@ +import os + +from aim.ext.resource.stat import Stat + + +try: + import accelerate.utils.environment +except ImportError: + raise RuntimeError( + 'This contrib module requires HuggingFace Accelerate to be installed. ' + 'Please install it with command: \n pip install accelerate' + ) + +import copy +import json +import logging +import select +import socket +import struct +import threading +import time +import typing + +import aim +import aim.ext.resource +import aim.hugging_face +import aim.sdk.configs + + +class IncompletePackageError(Exception): + pass + + +class IncompleteHeaderError(IncompletePackageError): + pass + + +class IncompleteDataError(IncompletePackageError): + pass + + +def packet_encode(usage: typing.Dict[str, typing.Any]) -> bytes: + data = json.dumps(usage) + header = len(data).to_bytes(4, 'big') + packet = b''.join((header, struct.pack(f'!{len(data)}s', data.encode('utf-8')))) + return packet + + +def packet_decode(packet: bytes) -> typing.Dict[str, typing.Any]: + length = int.from_bytes(packet[:4], 'big') + raw = struct.unpack_from(f'!{length}s', packet, 4)[0] + decoded = json.loads(raw) + return decoded + + +class ResourceTrackerForwarder(aim.ext.resource.ResourceTracker): + def _track(self, stat: Stat): + # Instead of tracking individual system metrics, forward the entire update to the MetricsReporter + # in turn, the MetricsReporter will create a packet ouf of Stat and its context (rank, world_size, etc). 
+ # Next, it'll send that packet to the MetricsReceiver which will then push the data to the Aim server + self._tracker()(stat) + + +class MetricsReporter: + def __init__( + self, + host: str, + port: int, + node_rank: int, + rank: int, + interval: typing.Union[int, float], + ): + self.client: typing.Optional[socket.socket] = None + + self.node_rank = node_rank + self.rank = rank + self.log = logging.getLogger(f'MetricsReporter{rank}') + + self._connect(host=host, port=port) + self.tracker = ResourceTrackerForwarder(tracker=self, interval=interval, capture_logs=False) + + def start(self): + self.tracker.start() + + def stop(self): + if self.tracker._shutdown is False: + self.tracker.stop() + if self.client is not None: + self.client.close() + self.client = None + + def _connect( + self, + host: str, + port: int, + connection_timeout: int = 60 * 10, + retry_seconds: int = 5, + ): + start = time.time() + + while time.time() - start <= connection_timeout: + # This should deal with both ipv4 and ipv6 hosts + for family, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, proto=socket.SOL_TCP): + self.client = socket.socket(family, socktype, proto) + try: + self.client.connect(sa) + return + except (ConnectionRefusedError, OSError) as e: + self.client.close() + self.log.info( + f'Could not connect to main worker due to {e} - will retry in {retry_seconds} seconds' + ) + time.sleep(retry_seconds) + + raise ConnectionError(f'Could not connect to server {host}:{port} after {connection_timeout} seconds') + + def __call__(self, stat: aim.ext.resource.tracker.Stat): + if self.client is None: + self.log.info('Connection has already closed, will not propagate this system metrics snapshot') + return + + # This is invoked by @self.tracker + raw = { + 'stat': stat.stat_item.to_dict(), + 'worker': { + 'node_rank': self.node_rank, + 'rank': self.rank, + }, + } + self.log.debug(f'Send {raw}') + + packet = packet_encode(raw) + try: + self.client.sendall(packet) + except BrokenPipeError: + self.log.info( + f'BrokenPipeError while transmitting system metrics {raw} - will stop recording system metrics' + ) + try: + self.stop() + except RuntimeError as e: + if e.args[0] != 'cannot join current thread': + # Calling stop() causes self.tracker() to try to join this thread. 
In turn that raises + # this RuntimeError + raise + except Exception as e: + self.log.info(f'{e} while transmitting system metrics {raw} - will ignore exception') + + +class MetricsReceiver: + def __init__( + self, + host: str, + port: int, + num_workers: int, + connection_timeout: int, + ): + self.tracker: typing.Optional[ + typing.Callable[ + [ + typing.Dict[str, typing.Any], + typing.Dict[str, typing.Any], + ], + None, + ] + ] = None + + self.clients: typing.List[socket.socket] = [] + self.log = logging.getLogger('MetricsReceiver') + self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + self._wait_workers( + host=host, + port=port, + num_workers=num_workers, + connection_timeout=connection_timeout, + ) + + self.running = True + self.thread: typing.Optional[threading.Thread] = None + + def start( + self, + tracker: typing.Callable[ + [ + typing.Dict[str, typing.Any], + typing.Dict[str, typing.Any], + ], + None, + ], + ): + self.tracker = tracker + self.running = True + + self.thread = threading.Thread(target=self._collect_metrics, daemon=True) + self.thread.start() + + def stop(self): + if self.running: + self.running = False + self.thread.join() + + def _recv(self, sock: socket.socket, length: int) -> typing.Optional[bytes]: + data = b'' + retries = 0 + while len(data) < length and retries < 10: + buf = sock.recv(length - len(data)) + data += buf + retries += 1 + + if len(data) < length: + if len(data) > 0: + raise IncompletePackageError() + # If recv() returned b'' then the client disconnected + return None + + return data + + def _recv_packet(self, sock: socket.socket) -> typing.Optional[typing.Dict[str, typing.Any]]: + try: + header = self._recv(sock, 4) + + if header is None: + # The client disconnected + return None + + length = int.from_bytes(header, 'big') + except IncompletePackageError: + raise IncompleteHeaderError() + try: + data = self._recv(sock, length) + if len(data) > 0: + return json.loads(data) + except IncompletePackageError: + raise IncompleteDataError() + + def _collect_metrics(self): + while self.running: + read, _write, _error = select.select(self.clients, [], [], 5.0) + for client in typing.cast(typing.List[socket.socket], read): + try: + packet = self._recv_packet(client) + except IncompletePackageError as e: + self.log.info(f'Error {e} while receiving update - will assume this is a transient error') + continue + + if packet: + self.tracker(packet['stat'], packet['worker']) + else: + self.log.info('Client disconnected') + client.close() + self.clients.remove(client) + + def _wait_workers(self, host: str, port: int, num_workers: int, connection_timeout: float): + # This may raise an exception, don't catch it here and let it flow to the caller + self.server.bind((host, port)) + self.server.listen(num_workers) + + # We're actually going to pause here, till we get all clients OR we run out of time + + start = time.time() + self.log.info(f'Waiting for {num_workers} workers to connect') + # Block for 5 seconds while waiting for new connections + while time.time() - start <= connection_timeout: + read, _write, _error = select.select([self.server], [], [], 5.0) + for server in read: + client, _client_address = server.accept() + self.clients.append(client) + self.log.info(f'Client {len(self.clients)}/{num_workers} connected') + + if len(self.clients) == num_workers: + return + + self.server.close() + + raise ConnectionError(f'{num_workers - len(self.clients)} out of {num_workers} total clients did not connect') + + +class 
AimCallback(aim.hugging_face.AimCallback): + def __init__( + self, + main_port: int, + repo: typing.Optional[str] = None, + experiment: typing.Optional[str] = None, + system_tracking_interval: typing.Optional[int] = aim.ext.resource.DEFAULT_SYSTEM_TRACKING_INT, + log_system_params: typing.Optional[bool] = True, + capture_terminal_logs: typing.Optional[bool] = True, + main_addr: typing.Optional[str] = None, + distributed_information: typing.Optional[accelerate.utils.environment.CPUInformation] = None, + connection_timeout: int = 60 * 5, + workers_only_on_rank_0: bool = True, + ): + """A HuggingFace TrainerCallback which registers the system metrics of all workers involved in the training + under a single Aim run. + + This code initializes aim.hugging_face.AimCallback() only on rank 0 - otherwise we'd end up with multiple + Aim runs. + + Args: + main_port (:obj:`int`): Configures the port that the main worker will listen on. If this is None + then the code will raise an exception. + repo (:obj:`Union[Repo,str]`, optional): Aim repository path or Repo object to which Run object is bound. + If skipped, default Repo is used. + experiment (:obj:`str`, optional): Sets Run's `experiment` property. 'default' if not specified. + Can be used later to query runs/sequences. + system_tracking_interval (:obj:`int`, optional): Sets the tracking interval in seconds for system usage + metrics (CPU, Memory, etc.). Set to `None` to disable system metrics tracking. + log_system_params (:obj:`bool`, optional): Enable/Disable logging of system params such as installed + packages, git info, environment variables, etc. + main_addr (:obj:`str`, optional): The address of the main worker. If this is None then the code will + auto-discover it from the environment variable MASTER_ADDR. If this parameter cannot be resolved + to a non-empty value the method will raise an exception. + distributed_information (:obj:`str`, accelerate.utils.environment.CPUInformation): information about the + CPU in a distributed environment. If None, the code parses environment variables to auto create it. + See accelerate.utils.get_cpu_distributed_information() for more details + connection_timeout (:obj:`int`, optional): Maximum seconds to wait for the auxiliary workers to connect. + workers_only_on_rank_0 (:obj:`bool`): When set to true, only treat processes with local_rank 0 as + workers. Setting this to False, only makes sense when debugging the AimCallback() code. + + Raises: + ConnectionError: + If unable auxiliary workers are unable to connect to main worker + """ + if main_addr is None: + main_addr = os.environ.get('MASTER_ADDR') + + if not main_addr: + raise ValueError('main_addr cannot be empty') + + if not main_port or main_port < 0: + raise ValueError('main_port must be a positive number') + + if distributed_information is None: + distributed_information = accelerate.utils.get_cpu_distributed_information() + + self.distributed_information = distributed_information + self.connection_timeout = connection_timeout + + self.listening_socket: typing.Optional[socket.socket] = None + + self.metrics_reporter: typing.Optional[MetricsReporter] = None + self.metrics_receiver: typing.Optional[MetricsReceiver] = None + + self._run: typing.Optional[aim.Run] = None + self.log = logging.getLogger('CustomAimCallback') + + if not workers_only_on_rank_0: + # This is primarily for debugging. 
It enables the creation of multiple auxiliary workers on a single node + auxiliary_workers = self.distributed_information.world_size + else: + auxiliary_workers = self.distributed_information.world_size // self.distributed_information.local_world_size + + # Instantiate a MetricsReporter for all workers which are not rank 0 + if ( + self.distributed_information.rank is not None + and self.distributed_information.rank > 0 + and (not workers_only_on_rank_0 or self.distributed_information.local_rank == 0) + and system_tracking_interval is not None + ): + if workers_only_on_rank_0: + node_rank = distributed_information.rank // distributed_information.local_world_size + else: + node_rank = distributed_information.rank + + self.metrics_reporter = MetricsReporter( + host=main_addr, + port=main_port, + rank=self.distributed_information.rank, + node_rank=node_rank, + interval=system_tracking_interval, + ) + + self.log.info(f'Distributed worker {self.distributed_information.rank} connected') + elif self.distributed_information.rank == 0: + # When running as the main worker, we initialize aim as usual. If there're multiple + # auxiliary workers, we also start a listening server. The auxiliary workers will connect + # to this server and periodically send over their system metrics + super().__init__( + repo, + experiment, + system_tracking_interval, + log_system_params, + capture_terminal_logs, + ) + + if auxiliary_workers > 1 and main_port is not None and system_tracking_interval is not None: + self.log.info(f'There are {auxiliary_workers} workers') + + self.metrics_receiver = MetricsReceiver( + # Bind to 0.0.0.0 so that we can accept connections coming in from any interface + host='0.0.0.0', + port=main_port, + num_workers=auxiliary_workers - 1, + connection_timeout=self.connection_timeout, + ) + + self.metrics_receiver.start(self._push_auxiliary_worker_metrics) + + def _push_auxiliary_worker_metrics( + self, + stat: typing.Dict[str, typing.Any], + worker_info: typing.Dict[str, typing.Any], + ): + """Utility method which pushes the system metrics of an auxiliary worker to Aim + + Args: + stat: (:obj:`typing.Dict[str, typing.Any]`): A dictionary representation of + aim.ext.resource.stat.Stat + worker_info (:obj:`typing.Dict[str, typing.Any]`): A dictionary which represents the context of a + worker. 
For example, it can contain the fields {"rank": int, "node_rank": int} + """ + # TODO: Investigate whether it's better to spin up a dedicated RunTracker here or not + if self._run is None: + self.log.info(f'The aim Run is inactive, will not register these metrics from {worker_info}') + return + + tracker = self._run._tracker + context = copy.deepcopy(worker_info) + + for resource, usage in stat['system'].items(): + tracker( + usage, + name='{}{}'.format( + aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, + resource, + ), + context=context, + ) + + # Store GPU stats + for gpu_idx, gpu in enumerate(stat['gpus']): + for resource, usage in gpu.items(): + context = copy.deepcopy(worker_info) + context.update({'gpu': gpu_idx}) + + tracker( + usage, + name='{}{}'.format( + aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, + resource, + ), + context=context, + ) + + def on_train_begin(self, args, state, control, model=None, **kwargs): + super().on_train_begin(args, state, control, model, **kwargs) + + if self.metrics_reporter: + self.metrics_reporter.start() + + def close(self): + try: + super().close() + finally: + if self.metrics_receiver is not None: + self.metrics_receiver.stop() + if self.metrics_reporter is not None: + self.metrics_reporter.stop() diff --git a/aim/sdk/adapters/fastai.py b/aim/sdk/adapters/fastai.py index 37390444cc..88b7c4fdd5 100644 --- a/aim/sdk/adapters/fastai.py +++ b/aim/sdk/adapters/fastai.py @@ -11,7 +11,7 @@ from fastcore.basics import detuplify, ignore_exceptions, store_attr except ImportError: raise RuntimeError( - 'This contrib module requires fastai to be installed. ' 'Please install it with command: \n pip install fastai' + 'This contrib module requires fastai to be installed. Please install it with command: \n pip install fastai' ) logger = getLogger(__name__) @@ -107,7 +107,11 @@ def gather_args(self): args['n_inp'] = n_inp xb = self.dls.valid.one_batch()[:n_inp] args.update( - {f'input {n+1} dim {i+1}': d for n in range(n_inp) for i, d in enumerate(list(detuplify(xb[n]).shape))} + { + f'input {n + 1} dim {i + 1}': d + for n in range(n_inp) + for i, d in enumerate(list(detuplify(xb[n]).shape)) + } ) except Exception: logger.warning('Failed to gather input dimensions') diff --git a/aim/sdk/adapters/keras.py b/aim/sdk/adapters/keras.py index 4a2141249e..10af8b7118 100644 --- a/aim/sdk/adapters/keras.py +++ b/aim/sdk/adapters/keras.py @@ -9,7 +9,7 @@ from keras.callbacks import Callback except ImportError: raise RuntimeError( - 'This contrib module requires keras to be installed. ' 'Please install it with command: \n pip install keras' + 'This contrib module requires keras to be installed. Please install it with command: \n pip install keras' ) diff --git a/aim/sdk/adapters/lightgbm.py b/aim/sdk/adapters/lightgbm.py index f2bae4e164..f006cd9718 100644 --- a/aim/sdk/adapters/lightgbm.py +++ b/aim/sdk/adapters/lightgbm.py @@ -8,8 +8,7 @@ from lightgbm.callback import CallbackEnv except ImportError: raise RuntimeError( - 'This contrib module requires Lightgbm to be installed. ' - 'Please install it with command: \n pip install lightgbm' + 'This contrib module requires Lightgbm to be installed. 
Please install it with command: \n pip install lightgbm' ) diff --git a/aim/sdk/adapters/mxnet.py b/aim/sdk/adapters/mxnet.py index e10d4a19c2..88f005dd8d 100644 --- a/aim/sdk/adapters/mxnet.py +++ b/aim/sdk/adapters/mxnet.py @@ -75,7 +75,7 @@ def train_begin(self, estimator: Optional[Estimator], *args, **kwargs): optimizer = trainer.optimizer.__class__.__name__ lr = trainer.learning_rate - estimator.logger.info('Training begin: using optimizer %s ' 'with current learning rate %.4f ', optimizer, lr) + estimator.logger.info('Training begin: using optimizer %s with current learning rate %.4f ', optimizer, lr) if estimator.max_epoch: estimator.logger.info('Train for %d epochs.', estimator.max_epoch) else: diff --git a/aim/sdk/adapters/pytorch_ignite.py b/aim/sdk/adapters/pytorch_ignite.py index 42cf7d0f2e..6a9506c54d 100644 --- a/aim/sdk/adapters/pytorch_ignite.py +++ b/aim/sdk/adapters/pytorch_ignite.py @@ -8,7 +8,7 @@ from torch.optim import Optimizer except ImportError: raise RuntimeError( - 'This contrib module requires PyTorch to be installed. ' 'Please install it with command: \n pip install torch' + 'This contrib module requires PyTorch to be installed. Please install it with command: \n pip install torch' ) try: from ignite.contrib.handlers.base_logger import ( @@ -185,8 +185,7 @@ def __call__(self, engine: Engine, logger: AimLogger, event_name: Union[str, Eve if not isinstance(global_step, int): raise TypeError( - f'global_step must be int, got {type(global_step)}.' - ' Please check the output of global_step_transform.' + f'global_step must be int, got {type(global_step)}. Please check the output of global_step_transform.' ) metrics = {} diff --git a/aim/sdk/adapters/xgboost.py b/aim/sdk/adapters/xgboost.py index 8d99262875..832110f254 100644 --- a/aim/sdk/adapters/xgboost.py +++ b/aim/sdk/adapters/xgboost.py @@ -8,8 +8,7 @@ from xgboost.callback import TrainingCallback except ImportError: raise RuntimeError( - 'This contrib module requires XGBoost to be installed. ' - 'Please install it with command: \n pip install xgboost' + 'This contrib module requires XGBoost to be installed. 
Please install it with command: \n pip install xgboost' ) diff --git a/aim/sdk/base_run.py b/aim/sdk/base_run.py index 89edf63b02..f77c435d8b 100644 --- a/aim/sdk/base_run.py +++ b/aim/sdk/base_run.py @@ -39,6 +39,7 @@ def __init__( if self.read_only: assert run_hash is not None self.hash = run_hash + self.meta_tree: TreeView = self.repo.request_tree('meta', read_only=True).subtree('meta') else: if run_hash is None: self.hash = generate_run_hash() @@ -48,10 +49,8 @@ def __init__( raise MissingRunError(f'Cannot find Run {run_hash} in aim Repo {self.repo.path}.') self._lock = self.repo.request_run_lock(self.hash) self._lock.lock(force=force_resume) + self.meta_tree: TreeView = self.repo.request_tree('meta', self.hash, read_only=False).subtree('meta') - self.meta_tree: TreeView = self.repo.request_tree( - 'meta', self.hash, read_only=read_only, from_union=True - ).subtree('meta') self.meta_run_tree: TreeView = self.meta_tree.subtree('chunks').subtree(self.hash) self._series_run_trees: Dict[int, TreeView] = None diff --git a/aim/sdk/callbacks/caller.py b/aim/sdk/callbacks/caller.py index 6ac0c29aee..387406e225 100644 --- a/aim/sdk/callbacks/caller.py +++ b/aim/sdk/callbacks/caller.py @@ -42,7 +42,7 @@ def trigger(self, event_name: str, **kwargs): for handler in handlers: try: handler(**all_kwargs) - except Exception: # noqa + except Exception: # TODO catch errors on handler invocation (nice-to-have) logger.warning(f"Failed to run callback '{handler.__name__}'.") logger.warning(traceback.format_exc()) diff --git a/aim/sdk/data_version.py b/aim/sdk/data_version.py index 55f4f52d61..4c496ddb5e 100644 --- a/aim/sdk/data_version.py +++ b/aim/sdk/data_version.py @@ -1 +1 @@ -DATA_VERSION = (1, 3) +DATA_VERSION = (1, 4) diff --git a/aim/sdk/index_manager.py b/aim/sdk/index_manager.py index 7c26cb2bfe..166e6ae0e8 100644 --- a/aim/sdk/index_manager.py +++ b/aim/sdk/index_manager.py @@ -1,27 +1,83 @@ -import contextlib -import datetime +import hashlib import logging import os +import queue +import threading import time from pathlib import Path -from threading import Thread -from typing import Iterable +from typing import Dict import aimrocks.errors -import pytz from aim.sdk.repo import Repo -from aim.sdk.run_status_watcher import Event -from aim.storage.locking import RefreshLock +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer +from watchdog.observers.api import ObservedWatch +from watchdog.observers.polling import PollingObserver logger = logging.getLogger(__name__) +class NewChunkCreatedHandler(FileSystemEventHandler): + def __init__(self, manager): + self.manager = manager + self.known_chunks = set(p.name for p in self.manager.chunks_dir.iterdir() if p.is_dir()) + + def on_modified(self, event): + if event.is_directory and Path(event.src_path) == self.manager.chunks_dir: + current_chunks = set(p.name for p in self.manager.chunks_dir.iterdir() if p.is_dir()) + new_chunks = current_chunks - self.known_chunks + for chunk_name in new_chunks: + chunk_path = self.manager.chunks_dir / chunk_name + logger.debug(f'Detected new chunk directory: {chunk_name}') + self.manager.monitor_chunk_directory(chunk_path) + self.known_chunks = current_chunks + + +class ChunkChangedHandler(FileSystemEventHandler): + def __init__(self, manager): + self.manager = manager + self.pending_events = set() + self.lock = threading.Lock() + + def _trigger_event(self, run_hash): + with self.lock: + if run_hash not in self.pending_events: + self.pending_events.add(run_hash) + 
threading.Timer(0.5, self._process_event, [run_hash]).start() + + def _process_event(self, run_hash): + with self.lock: + if run_hash in self.pending_events: + self.pending_events.remove(run_hash) + logger.debug(f'Triggering indexing for run {run_hash}') + self.manager.add_run_to_queue(run_hash) + + def on_any_event(self, event): + if event.is_directory: + return + + event_path = Path(event.src_path) + parent_dir = event_path.parent + run_hash = parent_dir.name + + # Ensure the parent directory is directly inside meta/chunks/ + if parent_dir.parent != self.manager.chunks_dir: + logger.debug(f'Skipping event outside valid chunk directory: {event.src_path}') + return + + if event_path.name.startswith('LOG'): + logger.debug(f'Skipping event for LOG-prefixed file: {event.src_path}') + return + + logger.debug(f'Detected change in {event.src_path}') + self._trigger_event(run_hash) + + class RepoIndexManager: index_manager_pool = {} - INDEXING_GRACE_PERIOD = 10 @classmethod def get_index_manager(cls, repo: Repo): @@ -34,158 +90,133 @@ def get_index_manager(cls, repo: Repo): def __init__(self, repo: Repo): self.repo_path = repo.path self.repo = repo - self.progress_dir = Path(self.repo_path) / 'meta' / 'progress' - self.progress_dir.mkdir(parents=True, exist_ok=True) + self.chunks_dir = Path(self.repo_path) / 'meta' / 'chunks' + self.chunks_dir.mkdir(parents=True, exist_ok=True) - self.heartbeat_dir = Path(self.repo_path) / 'check_ins' - self.run_heartbeat_cache = {} - - self._indexing_in_progress = False - self._reindex_thread: Thread = None self._corrupted_runs = set() - @property - def repo_status(self): - if self._indexing_in_progress is True: - return 'indexing in progress' - if self.reindex_needed: - return 'needs indexing' - return 'up-to-date' - - @property - def reindex_needed(self) -> bool: - runs_with_progress = os.listdir(self.progress_dir) - return len(runs_with_progress) > 0 - - def start_indexing_thread(self): - logger.info(f"Starting indexing thread for repo '{self.repo_path}'") - self._reindex_thread = Thread(target=self._run_forever, daemon=True) - self._reindex_thread.start() - - def _run_forever(self): - idle_cycles = 0 - while True: - self._indexing_in_progress = False - for run_hash in self._next_stalled_run(): - logger.info(f'Found un-indexed run {run_hash}. Indexing...') - self._indexing_in_progress = True - idle_cycles = 0 - self.index(run_hash) - - # sleep for small interval to release index db lock in between and allow - # other running jobs to properly finalize and index Run. - sleep_interval = 0.1 - time.sleep(sleep_interval) - if not self._indexing_in_progress: - idle_cycles += 1 - sleep_interval = 2 * idle_cycles if idle_cycles < 5 else 10 - logger.info( - f'No un-indexed runs found. Next check will run in {sleep_interval} seconds. ' - f'Waiting for un-indexed run...' 
- ) - time.sleep(sleep_interval) - - def _runs_with_progress(self) -> Iterable[str]: - runs_with_progress = filter(lambda x: x not in self._corrupted_runs, os.listdir(self.progress_dir)) - run_hashes = sorted(runs_with_progress, key=lambda r: os.path.getmtime(os.path.join(self.progress_dir, r))) - return run_hashes - - def _next_stalled_run(self): - for run_hash in self._runs_with_progress(): - if self._is_run_stalled(run_hash): - yield run_hash - - def _is_run_stalled(self, run_hash: str) -> bool: - stalled = False - heartbeat_files = list(sorted(self.heartbeat_dir.glob(f'{run_hash}-*-progress-*-*'), reverse=True)) - if heartbeat_files: - last_heartbeat = Event(heartbeat_files[0].name) - last_recorded_heartbeat = self.run_heartbeat_cache.get(run_hash) - if last_recorded_heartbeat is None: - self.run_heartbeat_cache[run_hash] = last_heartbeat - elif last_heartbeat.idx > last_recorded_heartbeat.idx: - self.run_heartbeat_cache[run_hash] = last_heartbeat - else: - time_passed = time.time() - last_recorded_heartbeat.detected_epoch_time - if last_recorded_heartbeat.next_event_in + RepoIndexManager.INDEXING_GRACE_PERIOD < time_passed: - stalled = True + self.indexing_queue = queue.PriorityQueue() + self.lock = threading.Lock() + + self.new_chunk_observer = Observer() + self.chunk_change_observer = PollingObserver() + + self.new_chunk_handler = NewChunkCreatedHandler(self) + self.chunk_change_handler = ChunkChangedHandler(self) + self._watches: Dict[str, ObservedWatch] = dict() + self.new_chunk_observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=False) + + self._stop_event = threading.Event() + self._index_thread = None + self._monitor_thread = None + + def start(self): + self._stop_event.clear() + self.new_chunk_observer.start() + self.chunk_change_observer.start() + + if not self._index_thread or not self._index_thread.is_alive(): + self._index_thread = threading.Thread(target=self._process_indexing_queue, daemon=True) + self._index_thread.start() + + if not self._monitor_thread or not self._monitor_thread.is_alive(): + self._monitor_thread = threading.Thread(target=self._monitor_existing_chunks, daemon=True) + self._monitor_thread.start() + + def stop(self): + self._stop_event.set() + self.new_chunk_observer.stop() + self.chunk_change_observer.stop() + if self._monitor_thread: + self._monitor_thread.join() + if self._index_thread: + self._index_thread.join() + + def _monitor_existing_chunks(self): + while not self._stop_event.is_set(): + index_db = self.repo.request_tree('meta', read_only=True) + monitored_chunks = set(self._watches.keys()) + for chunk_path in self.chunks_dir.iterdir(): + if ( + chunk_path.is_dir() + and chunk_path.name not in monitored_chunks + and self._is_run_index_outdated(chunk_path.name, index_db) + ): + logger.debug(f'Monitoring existing chunk: {chunk_path}') + self.monitor_chunk_directory(chunk_path) + logger.debug(f'Triggering indexing for run {chunk_path.name}') + self.add_run_to_queue(chunk_path.name) + self.repo.container_pool.clear() + time.sleep(5) + + def _stop_monitoring_chunk(self, run_hash): + watch = self._watches.pop(run_hash, None) + if watch: + self.chunk_change_observer.unschedule(watch) + logger.debug(f'Stopped monitoring chunk: {run_hash}') + + def monitor_chunk_directory(self, chunk_path): + """Ensure chunk directory is monitored using a single handler.""" + if chunk_path.name not in self._watches: + watch = self.chunk_change_observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) + self._watches[chunk_path.name] = 
watch + logger.debug(f'Started monitoring chunk directory: {chunk_path}') else: - stalled = True - return stalled - - def _index_lock_path(self): - return Path(self.repo.path) / 'locks' / 'index' - - @contextlib.contextmanager - def lock_index(self, lock: RefreshLock): + logger.debug(f'Chunk directory already monitored: {chunk_path}') + + def add_run_to_queue(self, run_hash): + if run_hash in self._corrupted_runs: + return + timestamp = os.path.getmtime(os.path.join(self.chunks_dir, run_hash)) + with self.lock: + self.indexing_queue.put((timestamp, run_hash)) + logger.debug(f'Run {run_hash} added to indexing queue with timestamp {timestamp}') + + def _process_indexing_queue(self): + while not self._stop_event.is_set(): + _, run_hash = self.indexing_queue.get() + logger.debug(f'Indexing run {run_hash}...') + self.index(run_hash) + self.indexing_queue.task_done() + + def index(self, run_hash): + index = self.repo._get_index_tree('meta', 0).view(()) try: - self._safe_acquire_lock(lock) - yield - finally: - lock.release() - - def _safe_acquire_lock(self, lock: RefreshLock): - last_touch_seen = None - prev_touch_time = None - last_owner_id = None - while True: - try: - lock.acquire() - logger.debug('Lock is acquired!') - break - except TimeoutError: - owner_id = lock.owner_id() - if owner_id != last_owner_id: - logger.debug(f'Lock has been acquired by {owner_id}') - last_owner_id = owner_id - prev_touch_time = None - else: # same holder as from prev. iteration - last_touch_time = lock.last_refresh_time() - if last_touch_time != prev_touch_time: - prev_touch_time = last_touch_time - last_touch_seen = time.time() - logger.debug(f'Lock has been refreshed. Touch time: {last_touch_time}') - continue - assert last_touch_seen is not None - if time.time() - last_touch_seen > RefreshLock.GRACE_PERIOD: - logger.debug('Grace period exceeded. Force-acquiring the lock.') - with lock.meta_lock(): - # double check holder ID - if lock.owner_id() != last_owner_id: # someone else grabbed lock - continue - else: - lock.force_release() - try: - lock.acquire() - logger.debug('lock has been forcefully acquired!') - break - except TimeoutError: - continue - else: - logger.debug( - f'Countdown to force-acquire lock. ' - f'Time remaining: {RefreshLock.GRACE_PERIOD - (time.time() - last_touch_seen)}' - ) - - def run_needs_indexing(self, run_hash: str) -> bool: - return os.path.exists(self.progress_dir / run_hash) - - def index( - self, - run_hash, - ) -> bool: - lock = RefreshLock(self._index_lock_path(), timeout=10) - with self.lock_index(lock): - index = self.repo._get_index_tree('meta', 0).view(()) - try: - meta_tree = self.repo.request_tree( - 'meta', run_hash, read_only=True, from_union=False, no_cache=True - ).subtree('meta') - meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) - meta_run_tree.finalize(index=index) - if meta_run_tree.get('end_time') is None: - index['meta', 'chunks', run_hash, 'end_time'] = datetime.datetime.now(pytz.utc).timestamp() - except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): - logger.warning(f"Indexing thread detected corrupted run '{run_hash}'. 
Skipping.") - self._corrupted_runs.add(run_hash) - return True + run_checksum = self._get_run_checksum(run_hash) + meta_tree = self.repo.request_tree('meta', run_hash, read_only=True, skip_read_optimization=True).subtree( + 'meta' + ) + meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) + meta_run_tree.finalize(index=index) + index['index_cache', run_hash] = run_checksum + + if meta_run_tree.get('end_time') is not None: + logger.debug(f'Indexing thread detected finished run: {run_hash}. Stopping monitoring...') + self._stop_monitoring_chunk(run_hash) + + except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): + logger.warning(f'Indexing thread detected corrupted run: {run_hash}. Skipping.') + self._corrupted_runs.add(run_hash) + return True + + def _is_run_index_outdated(self, run_hash, index_db): + return self._get_run_checksum(run_hash) != index_db.get(('index_cache', run_hash)) + + def _get_run_checksum(self, run_hash): + hash_obj = hashlib.md5() + + for root, dirs, files in os.walk(os.path.join(self.chunks_dir, run_hash)): + for name in sorted(files): # sort to ensure consistent order + if name.startswith('LOG'): # skip access logs + continue + filepath = os.path.join(root, name) + try: + stat = os.stat(filepath) + hash_obj.update(filepath.encode('utf-8')) + hash_obj.update(str(stat.st_mtime).encode('utf-8')) + hash_obj.update(str(stat.st_size).encode('utf-8')) + except FileNotFoundError: + # File might have been deleted between os.walk and os.stat + continue + + return hash_obj.hexdigest() diff --git a/aim/sdk/legacy/deprecation_warning.py b/aim/sdk/legacy/deprecation_warning.py deleted file mode 100644 index 36047509ee..0000000000 --- a/aim/sdk/legacy/deprecation_warning.py +++ /dev/null @@ -1,15 +0,0 @@ -import logging - -from functools import wraps - - -logger = logging.getLogger(__name__) - - -def deprecated(func): - @wraps(func) - def wrapper(*args, **kwargs): - logger.warning(msg=f'Usage of {func.__qualname__} is deprecated!') - return func(*args, **kwargs) - - return wrapper diff --git a/aim/sdk/legacy/flush.py b/aim/sdk/legacy/flush.py deleted file mode 100644 index 74dc48c2ae..0000000000 --- a/aim/sdk/legacy/flush.py +++ /dev/null @@ -1,6 +0,0 @@ -from aim.sdk.legacy.deprecation_warning import deprecated - - -@deprecated -def flush(): - pass diff --git a/aim/sdk/legacy/init.py b/aim/sdk/legacy/init.py deleted file mode 100644 index 6456e49506..0000000000 --- a/aim/sdk/legacy/init.py +++ /dev/null @@ -1,7 +0,0 @@ -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.legacy.session import DefaultSession - - -@deprecated -def init(*args, **kwargs): - DefaultSession(*args, **kwargs) diff --git a/aim/sdk/legacy/select.py b/aim/sdk/legacy/select.py deleted file mode 100644 index 4b637d13fa..0000000000 --- a/aim/sdk/legacy/select.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional - -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.repo import Repo - - -@deprecated -def select_metrics(search_statement: str, repo_path: Optional[str] = None): - if repo_path is not None: - repo = Repo.from_path(repo_path) - else: - repo = Repo.default_repo() - - if not repo: - return None - - return repo.query_metrics(search_statement) - - -@deprecated -def select_runs(expression: Optional[str] = None, repo_path: Optional[str] = None): - if repo_path is not None: - repo = Repo.from_path(repo_path) - else: - repo = Repo.default_repo() - - if not repo: - return None - - return repo.query_runs(expression) diff --git 
a/aim/sdk/legacy/session/__init__.py b/aim/sdk/legacy/session/__init__.py deleted file mode 100644 index 6c268677a9..0000000000 --- a/aim/sdk/legacy/session/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from aim.sdk.legacy.session.session import DefaultSession, Session diff --git a/aim/sdk/legacy/session/configs.py b/aim/sdk/legacy/session/configs.py deleted file mode 100644 index 08db0117c4..0000000000 --- a/aim/sdk/legacy/session/configs.py +++ /dev/null @@ -1 +0,0 @@ -DEFAULT_FLUSH_FREQUENCY = 128 diff --git a/aim/sdk/legacy/session/session.py b/aim/sdk/legacy/session/session.py deleted file mode 100644 index cc77b97449..0000000000 --- a/aim/sdk/legacy/session/session.py +++ /dev/null @@ -1,132 +0,0 @@ -import atexit -import os -import signal -import threading - -from typing import Optional - -from aim.ext.exception_resistant import exception_resistant -from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.repo import Repo -from aim.sdk.run import Run - - -class Session: - sessions = {} - - _are_exit_listeners_set = False - _original_sigint_handler = None - _original_sigterm_handler = None - - @deprecated - def __init__( - self, - repo: Optional[str] = None, - experiment: Optional[str] = None, - flush_frequency: int = 0, # unused - block_termination: bool = True, # unused - run: Optional[str] = None, - system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT, - ): - self._repo = Repo.from_path(repo) if repo else Repo.default_repo() - self._repo_path = self._repo.path - self._run = Run(run, repo=self._repo, experiment=experiment, system_tracking_interval=system_tracking_interval) - self._run_hash = self._run.hash - self.active = True - - Session.sessions.setdefault(self._repo_path, []) - Session.sessions[self._repo_path].append(self) - - # Bind signal listeners - self._set_exit_handlers() - - @property - def run_hash(self): - return self._run_hash - - @property - def repo_path(self): - return self._repo_path - - @exception_resistant(silent=False) - def track(self, *args, **kwargs): - val = args[0] - name = kwargs.pop('name') - step = kwargs.pop('step', None) - epoch = kwargs.pop('epoch', None) - for key in kwargs.keys(): - if key.startswith('__'): - del kwargs[key] - - self._run.track(val, name=name, step=step, epoch=epoch, context=kwargs) - - @exception_resistant(silent=False) - def set_params(self, params: dict, name: Optional[str] = None): - if name is None: - self._run[...] 
= params - else: - self._run[name] = params - - def flush(self): - pass - - @exception_resistant(silent=False) - def close(self): - if not self.active: - raise Exception('session is closed') - if self._run: - del self._run - self._run = None - if self._repo_path in Session.sessions and self in Session.sessions[self._repo_path]: - Session.sessions[self._repo_path].remove(self) - if len(Session.sessions[self._repo_path]) == 0: - del Session.sessions[self._repo_path] - self.active = False - - @classmethod - def _close_sessions(cls, *args, **kwargs): - threads = [] - for _, sessions in cls.sessions.items(): - for session in sessions: - th = threading.Thread(target=session.close) - th.daemon = True - threads.append(th) - - for th in threads: - th.start() - - for th in threads: - th.join() - - if len(args): - if args[0] == 15: - signal.signal(signal.SIGTERM, cls._original_sigterm_handler) - os.kill(os.getpid(), 15) - # elif args[0] == 2: - # signal.signal(signal.SIGINT, cls._original_sigint_handler) - # os.kill(os.getpid(), 2) - - @classmethod - def _set_exit_handlers(cls): - if not cls._are_exit_listeners_set: - cls._are_exit_listeners_set = True - # cls._original_sigint_handler = signal.getsignal(signal.SIGINT) - cls._original_sigterm_handler = signal.getsignal(signal.SIGTERM) - - atexit.register(cls._close_sessions) - # signal.signal(signal.SIGINT, cls._close_sessions) - signal.signal(signal.SIGTERM, cls._close_sessions) - - -DefaultSession = Session - - -def get_default_session() -> Session: - if len(Session.sessions.keys()) > 0: - default_sess_key = list(Session.sessions.keys())[0] - if len(Session.sessions[default_sess_key]) > 0: - return Session.sessions[default_sess_key][0] - - # Create and return default session otherwise - return DefaultSession() diff --git a/aim/sdk/legacy/track.py b/aim/sdk/legacy/track.py deleted file mode 100644 index ef04f2bcc7..0000000000 --- a/aim/sdk/legacy/track.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Optional - -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.legacy.session.session import get_default_session - - -@deprecated -def track(*args, **kwargs): - sess = get_default_session() - return sess.track(*args, **kwargs) - - -@deprecated -def set_params(params: dict, name: Optional[str] = None): - sess = get_default_session() - return sess.set_params(params, name) diff --git a/aim/sdk/objects/io/wavfile.py b/aim/sdk/objects/io/wavfile.py index 5c58daf4a0..34d187c7a2 100644 --- a/aim/sdk/objects/io/wavfile.py +++ b/aim/sdk/objects/io/wavfile.py @@ -316,7 +316,7 @@ def _raise_bad_format(format_tag): except ValueError: format_name = f'{format_tag:#06x}' raise ValueError( - f"Unknown wave file format: {format_name}. Supported formats: {', '.join(x.name for x in KNOWN_WAVE_FORMATS)}" + f'Unknown wave file format: {format_name}. 
Supported formats: {", ".join(x.name for x in KNOWN_WAVE_FORMATS)}' ) @@ -447,12 +447,12 @@ def _read_data_chunk(fid, format_tag, channels, bit_depth, is_big_endian, block_ # Remaining bit depths can map directly to signed numpy dtypes dtype = f'{fmt}i{bytes_per_sample}' else: - raise ValueError('Unsupported bit depth: the WAV file ' f'has {bit_depth}-bit integer data.') + raise ValueError(f'Unsupported bit depth: the WAV file has {bit_depth}-bit integer data.') elif format_tag == WAVE_FORMAT.IEEE_FLOAT: if bit_depth in {32, 64}: dtype = f'{fmt}f{bytes_per_sample}' else: - raise ValueError('Unsupported bit depth: the WAV file ' f'has {bit_depth}-bit floating-point data.') + raise ValueError(f'Unsupported bit depth: the WAV file has {bit_depth}-bit floating-point data.') else: _raise_bad_format(format_tag) @@ -480,7 +480,7 @@ def _read_data_chunk(fid, format_tag, channels, bit_depth, is_big_endian, block_ data = numpy.memmap(fid, dtype=dtype, mode='c', offset=start, shape=(n_samples,)) fid.seek(start + size) else: - raise ValueError('mmap=True not compatible with ' f'{bytes_per_sample}-byte container size.') + raise ValueError(f'mmap=True not compatible with {bytes_per_sample}-byte container size.') _handle_pad_byte(fid, size) @@ -516,7 +516,7 @@ def _read_riff_chunk(fid): fmt = '>I' else: # There are also .wav files with "FFIR" or "XFIR" signatures? - raise ValueError(f'File format {repr(str1)} not understood. Only ' "'RIFF' and 'RIFX' supported.") + raise ValueError(f"File format {repr(str1)} not understood. Only 'RIFF' and 'RIFX' supported.") # Size of entire file file_size = struct.unpack(fmt, fid.read(4))[0] + 8 @@ -554,7 +554,7 @@ def read(buffer, mmap=False): if data_chunk_received: # End of file but data successfully read warnings.warn( - 'Reached EOF prematurely; finished at {:d} bytes, ' 'expected {:d} bytes from header.'.format( + 'Reached EOF prematurely; finished at {:d} bytes, expected {:d} bytes from header.'.format( fid.tell(), file_size ), WavFileWarning, diff --git a/aim/sdk/query_analyzer.py b/aim/sdk/query_analyzer.py new file mode 100644 index 0000000000..f06d589f02 --- /dev/null +++ b/aim/sdk/query_analyzer.py @@ -0,0 +1,165 @@ +import ast +import sys + +from typing import Any, List, Tuple + + +class Unknown(ast.AST): + pass + + +Unknown = Unknown() # create a single instance of value node + +if sys.version_info.minor < 9: + import astunparse + + def unparse(*args, **kwargs): + return astunparse.unparse(*args, **kwargs) +else: + + def unparse(*args, **kwargs): + return ast.unparse(*args, **kwargs) + + +class QueryExpressionTransformer(ast.NodeTransformer): + def __init__(self, *, var_names: List[str]): + self._var_names = var_names + + def transform(self, expr: str) -> Tuple[str, bool]: + if expr: + node = ast.parse(expr, mode='eval') + transformed = self.visit(node) + if transformed is Unknown: + return expr, False + else: + return unparse(transformed), True + else: + return expr, False + + def visit_Expression(self, node: ast.Expression) -> Any: + node: ast.Expression = self.generic_visit(node) + if node.body is Unknown: + return Unknown + return node + + def visit_Expr(self, node: ast.Expr) -> Any: + node: ast.Expr = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Constant(self, node: ast.Constant) -> Any: + return node + + def visit_JoinedStr(self, node: ast.JoinedStr) -> Any: + node: ast.JoinedStr = self.generic_visit(node) + for val in node.values: + if val is Unknown: + return Unknown + return node + + def 
visit_FormattedValue(self, node: ast.FormattedValue) -> Any: + node: ast.FormattedValue = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in self._var_names: + return Unknown + else: + return node + + def visit_Compare(self, node: ast.Compare) -> Any: + node: ast.Compare = self.generic_visit(node) + if node.left is Unknown: + return Unknown + for comp in node.comparators: + if comp is Unknown: + return Unknown + return node + + def visit_List(self, node: ast.List) -> Any: + node: ast.List = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Tuple(self, node: ast.Tuple) -> Any: + node: ast.Tuple = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Dict(self, node: ast.Dict) -> Any: + node: ast.Dict = self.generic_visit(node) + for key in node.keys: + if key is Unknown: + return Unknown + for val in node.values: + if val is Unknown: + return Unknown + return node + + def visit_BoolOp(self, node: ast.BoolOp) -> Any: + node: ast.BoolOp = self.generic_visit(node) + node_values = list(filter(lambda x: x is not Unknown, node.values)) + if isinstance(node.op, ast.And): + if len(node_values) == 1: + return node_values[0] + elif len(node_values) == 0: + return Unknown + else: + if len(node_values) < len(node.values): + return Unknown + return ast.BoolOp(op=node.op, values=node_values) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + node: ast.UnaryOp = self.generic_visit(node) + if node.operand is Unknown: + return Unknown + return node + + def visit_BinOp(self, node: ast.BinOp) -> Any: + node: ast.BinOp = self.generic_visit(node) + if node.left is Unknown or node.right is Unknown: + return Unknown + return node + + def visit_IfExp(self, node: ast.IfExp) -> Any: + node: ast.IfExp = self.generic_visit(node) + if node.test is Unknown or node.body is Unknown or node.orelse is Unknown: + return Unknown + return node + + def visit_Attribute(self, node: ast.Attribute) -> Any: + node: ast.Attribute = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Call(self, node: ast.Call) -> Any: + node: ast.Call = self.generic_visit(node) + if node.func is Unknown: + return Unknown + for arg in node.args: + if arg is Unknown: + return Unknown + for kwarg in node.keywords: + if kwarg is Unknown: + return Unknown + return node + + def visit_Subscript(self, node: ast.Subscript) -> Any: + node: ast.Subscript = self.generic_visit(node) + if node.value is Unknown or node.slice is Unknown: + return Unknown + return node + + def visit_Slice(self, node: ast.Slice) -> Any: + node: ast.Slice = self.generic_visit(node) + if node.lower is Unknown or node.upper is Unknown or node.step is Unknown: + return Unknown + return node diff --git a/aim/sdk/repo.py b/aim/sdk/repo.py index 9927949649..1ffef1c9b4 100644 --- a/aim/sdk/repo.py +++ b/aim/sdk/repo.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple from weakref import WeakValueDictionary +import aimrocks.errors + from aim.ext.cleanup import AutoClean from aim.ext.sshfs.utils import mount_remote_repo, unmount_remote_repo from aim.ext.task_queue.queue import TaskQueue @@ -127,9 +129,14 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona self.root_path = path self.path = os.path.join(self.root_path, get_aim_repo_name()) 
- if init: + if init and not self.is_remote_repo: os.makedirs(self.path, exist_ok=True) os.makedirs(os.path.join(self.path, 'locks'), exist_ok=True) + + # Make sure meta index db is created + path = os.path.join(self.path, 'meta', 'index') + RocksContainer(path, read_only=False) + if not self.is_remote_repo and not os.path.exists(self.path): if self._mount_root: unmount_remote_repo(self.root_path, self._mount_root) @@ -137,7 +144,6 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona self.container_pool: Dict[ContainerConfig, Container] = WeakValueDictionary() self.persistent_pool: Dict[ContainerConfig, Container] = dict() - self.container_view_pool: Dict[ContainerConfig, Container] = WeakValueDictionary() self._run_props_cache_hint = None self._encryption_key = None @@ -160,7 +166,7 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona @property def meta_tree(self): - return self.request_tree('meta', read_only=True, from_union=True).subtree('meta') + return self.request_tree('meta', read_only=True).subtree('meta') def __repr__(self) -> str: return f'' @@ -269,23 +275,6 @@ def get_version(cls, path: str): def is_remote_path(cls, path: str): return path.startswith('aim://') - def _get_container(self, name: str, read_only: bool, from_union: bool = False) -> Container: - if self.read_only and not read_only: - raise ValueError('Repo is read-only') - - container_config = ContainerConfig(name, None, read_only=read_only) - container = self.container_pool.get(container_config) - if container is None: - path = os.path.join(self.path, name) - if from_union: - container = RocksUnionContainer(path, read_only=read_only) - self.persistent_pool[container_config] = container - else: - container = RocksContainer(path, read_only=read_only) - self.container_pool[container_config] = container - - return container - def _get_index_tree(self, name: str, timeout: int): if not self.is_remote_repo: return self._get_index_container(name, timeout).tree() @@ -306,49 +295,30 @@ def _get_index_container(self, name: str, timeout: int) -> Container: return container - def request_tree( - self, - name: str, - sub: str = None, - *, - read_only: bool, - from_union: bool = False, # TODO maybe = True by default - no_cache: bool = False, - ): + def request_tree(self, name: str, sub: str = None, *, read_only: bool, skip_read_optimization: bool = False): if not self.is_remote_repo: - return self.request(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache).tree() + return self.request_container( + name, sub, read_only=read_only, skip_read_optimization=skip_read_optimization + ).tree() else: - return ProxyTree(self._client, name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache) + return ProxyTree(self._client, name, sub, read_only=read_only) - def request( - self, - name: str, - sub: str = None, - *, - read_only: bool, - from_union: bool = False, # TODO maybe = True by default - no_cache: bool = False, - ): + def request_container(self, name: str, sub: str = None, *, read_only: bool, skip_read_optimization: bool = False): container_config = ContainerConfig(name, sub, read_only) - container_view = self.container_view_pool.get(container_config) - if container_view is None or no_cache: - if read_only: - if from_union: - path = name - else: - assert sub is not None - path = os.path.join(name, 'chunks', sub) - container = self._get_container(path, read_only=True, from_union=from_union) + container = 
self.container_pool.get(container_config) + if container is None: + if sub is None: + try: + path = os.path.join(self.path, name, 'index') + container = RocksContainer(path, read_only=True, skip_read_optimization=skip_read_optimization) + except aimrocks.errors.RocksIOError: + path = os.path.join(self.path, name) + container = RocksUnionContainer(path, read_only=True) else: - assert sub is not None - path = os.path.join(name, 'chunks', sub) - container = self._get_container(path, read_only=False, from_union=False) - - container_view = container - if not no_cache: - self.container_view_pool[container_config] = container_view - - return container_view + path = os.path.join(self.path, name, 'chunks', sub) + container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) + self.container_pool[container_config] = container + return container def request_props(self, hash_: str, read_only: bool, created_at: 'datetime' = None): if self.is_remote_repo: @@ -739,9 +709,6 @@ def encryption_key(self): return encryption_key - def _get_meta_tree(self): - return self.request_tree('meta', read_only=True, from_union=True).subtree('meta') - @staticmethod def available_sequence_types(): return Sequence.registry.keys() @@ -763,7 +730,6 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di Returns: :obj:`dict`: Tree of sequences and their contexts groupped by sequence type. """ - meta_tree = self._get_meta_tree() sequence_traces = {} if isinstance(sequence_types, str): sequence_types = (sequence_types,) @@ -776,7 +742,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di dtype_traces = set() for dtype in dtypes: try: - dtype_trace_tree = meta_tree.collect(('traces_types', dtype)) + dtype_trace_tree = self.meta_tree.collect(('traces_types', dtype)) for ctx_id, seqs in dtype_trace_tree.items(): for seq_name in seqs.keys(): dtype_traces.add((ctx_id, seq_name)) @@ -784,7 +750,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di pass if 'float' in dtypes: # old sequences without dtype set are considered float sequences try: - dtype_trace_tree = meta_tree.collect('traces') + dtype_trace_tree = self.meta_tree.collect('traces') for ctx_id, seqs in dtype_trace_tree.items(): for seq_name in seqs.keys(): dtype_traces.add((ctx_id, seq_name)) @@ -792,7 +758,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di pass traces_info = defaultdict(list) for ctx_id, seq_name in dtype_traces: - traces_info[seq_name].append(meta_tree['contexts', ctx_id]) + traces_info[seq_name].append(self.meta_tree['contexts', ctx_id]) sequence_traces[seq_type] = traces_info return sequence_traces @@ -802,9 +768,8 @@ def collect_params_info(self) -> dict: Returns: :obj:`dict`: All runs meta-parameters. 
""" - meta_tree = self._get_meta_tree() try: - return meta_tree.collect('attrs', strict=False) + return self.meta_tree.collect('attrs', strict=False) except KeyError: return {} @@ -875,22 +840,13 @@ def _delete_run(self, run_hash): def _copy_run(self, run_hash, dest_repo): def copy_trees(): # copy run meta tree - source_meta_tree = self.request_tree( - 'meta', run_hash, read_only=True, from_union=False, no_cache=True - ).subtree('meta') - dest_meta_tree = dest_repo.request_tree( - 'meta', run_hash, read_only=False, from_union=False, no_cache=True - ).subtree('meta') - dest_meta_run_tree = dest_meta_tree.subtree('chunks').subtree(run_hash) + source_meta_tree = self.request_tree('meta', run_hash, read_only=True).subtree('meta') + dest_meta_tree = dest_repo.request_tree('meta', run_hash, read_only=False).subtree('meta') dest_meta_tree[...] = source_meta_tree[...] - dest_index = dest_repo._get_index_tree('meta', timeout=10).view(()) - dest_meta_run_tree.finalize(index=dest_index) # copy run series tree - source_series_run_tree = self.request_tree('seqs', run_hash, read_only=True, no_cache=True).subtree('seqs') - dest_series_run_tree = dest_repo.request_tree('seqs', run_hash, read_only=False, no_cache=True).subtree( - 'seqs' - ) + source_series_run_tree = self.request_tree('seqs', run_hash, read_only=True).subtree('seqs') + dest_series_run_tree = dest_repo.request_tree('seqs', run_hash, read_only=False).subtree('seqs') # copy v2 sequences source_v2_tree = source_series_run_tree.subtree(('v2', 'chunks', run_hash)) @@ -985,7 +941,7 @@ def _backup_run(self, run_hash): from aim.sdk.utils import backup_run if self.is_remote_repo: - self._remote_repo_proxy._restore_run(run_hash) # noqa + self._remote_repo_proxy._restore_run(run_hash) else: backup_run(self, run_hash) @@ -993,11 +949,15 @@ def _restore_run(self, run_hash): from aim.sdk.utils import restore_run_backup if self.is_remote_repo: - self._remote_repo_proxy._restore_run(run_hash) # noqa + self._remote_repo_proxy._restore_run(run_hash) else: restore_run_backup(self, run_hash) def _close_run(self, run_hash): + import datetime + + import pytz + def optimize_container(path, extra_options): rc = RocksContainer(path, read_only=True, **extra_options) rc.optimize_for_read() @@ -1005,19 +965,24 @@ def optimize_container(path, extra_options): if self.is_remote_repo: self._remote_repo_proxy._close_run(run_hash) - from aim.sdk.index_manager import RepoIndexManager - lock_manager = LockManager(self.path) - index_manager = RepoIndexManager.get_index_manager(self) if lock_manager.release_locks(run_hash, force=True): + # Set run end time if locks are removed + meta_tree = self.request_tree( + 'meta', + run_hash, + read_only=False, + ).subtree('meta') + meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) + if not meta_run_tree.get('end_time'): + meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() + # Run rocksdb optimizations if container locks are removed meta_db_path = os.path.join(self.path, 'meta', 'chunks', run_hash) seqs_db_path = os.path.join(self.path, 'seqs', 'chunks', run_hash) optimize_container(meta_db_path, extra_options={'compaction': True}) optimize_container(seqs_db_path, extra_options={}) - if index_manager.run_needs_indexing(run_hash): - index_manager.index(run_hash) def _recreate_index(self): from tqdm import tqdm diff --git a/aim/sdk/run.py b/aim/sdk/run.py index 08b89b8c37..b53bdf72af 100644 --- a/aim/sdk/run.py +++ b/aim/sdk/run.py @@ -82,9 +82,9 @@ def __init__(self, instance: 'Run') -> None: def 
add_extra_resource(self, resource) -> None: self.extra_resources.append(resource) - def finalize_run(self): + def set_run_end_time(self): """ - Finalize the run by indexing all the data. + Set Run end_time to mark it as finished. """ self.meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() @@ -94,7 +94,7 @@ def empty_rpc_queue(self): def _close(self) -> None: """ - Close the `Run` instance resources and trigger indexing. + Close the `Run` instance resources. """ if self.read_only: logger.debug(f'Run {self.hash} is read-only, skipping cleanup') @@ -104,7 +104,7 @@ def _close(self) -> None: res.close() self.empty_rpc_queue() - self.finalize_run() + self.set_run_end_time() if self._heartbeat is not None: self._heartbeat.stop() if self._checkins is not None: @@ -287,7 +287,7 @@ def __init__( raise RuntimeError else: logger.warning(f'Detected sub-optimal format metrics for Run {self.hash}. Upgrading...') - backup_path = backup_run(self) + backup_path = backup_run(self.repo, self.hash) try: self.update_metrics() logger.warning(f'Successfully converted Run {self.hash}') @@ -567,14 +567,6 @@ def get_metric(self, name: str, context: Context) -> Optional['Metric']: Returns: :obj:`Metric` object if exists, `None` otherwise. """ - if self.read_only and not Run._metric_version_warning_shown: - if self.check_metrics_version(): - logger.warning( - f'Detected sub-optimal format metrics for Run {self.hash}. Consider upgrading repo ' - f'to improve queries performance:' - ) - logger.warning(f"aim storage --repo {self.repo.path} upgrade 3.11+ '*'") - Run._metric_version_warning_shown = True return self._get_sequence('metric', name, context) @@ -733,12 +725,6 @@ def close(self): self._props = None self._cleanup_trees() - def finalize(self): - if self._resources is None: - return - - self._resources.finalize_run() - def dataframe( self, include_props: bool = True, diff --git a/aim/sdk/run_status_manager.py b/aim/sdk/run_status_manager.py new file mode 100644 index 0000000000..e1fa6f3fc4 --- /dev/null +++ b/aim/sdk/run_status_manager.py @@ -0,0 +1,95 @@ +import datetime +import os +import threading +import time + +from pathlib import Path +from typing import Iterable + +import aimrocks.errors +import pytz + +from aim import Repo +from aim.sdk.run_status_watcher import Event + + +class RunStatusManager: + INDEXING_GRACE_PERIOD = 10 + + def __init__(self, repo: Repo, scan_interval: int = 60): + self.repo = repo + self.scan_interval = scan_interval + + self.progress_dir = Path(self.repo.path) / 'meta' / 'progress' + self.progress_dir.mkdir(parents=True, exist_ok=True) + + self.heartbeat_dir = Path(self.repo.path) / 'check_ins' + self.run_heartbeat_cache = {} + + self._stop_event = threading.Event() + self._monitor_thread = None + self._corrupted_runs = set() + + def start(self): + if not self._monitor_thread or not self._monitor_thread.is_alive(): + self._stop_event.clear() + self._monitor_thread = threading.Thread(target=self._run_forever, daemon=True) + self._monitor_thread.start() + + def stop(self): + self._stop_event.set() + if self._monitor_thread: + self._monitor_thread.join() + + def _run_forever(self): + while not self._stop_event.is_set(): + self.check_and_terminate_stalled_runs() + time.sleep(self.scan_interval) + + def _runs_with_progress(self) -> Iterable[str]: + runs_with_progress = filter(lambda x: x not in self._corrupted_runs, os.listdir(self.progress_dir)) + run_hashes = sorted(runs_with_progress, key=lambda r: os.path.getmtime(os.path.join(self.progress_dir, r))) + return 
run_hashes + + def check_and_terminate_stalled_runs(self): + for run_hash in self._runs_with_progress(): + if self._is_run_stalled(run_hash): + self._mark_run_as_terminated(run_hash) + + def _is_run_stalled(self, run_hash: str) -> bool: + stalled = False + + heartbeat_files = list(sorted(self.heartbeat_dir.glob(f'{run_hash}-*-progress-*-*'), reverse=True)) + if heartbeat_files: + latest_file = heartbeat_files[0].name + last_heartbeat = Event(latest_file) + + last_recorded_heartbeat = self.run_heartbeat_cache.get(run_hash) + if last_recorded_heartbeat is None: + # First time seeing a heartbeat for this run; store and move on + self.run_heartbeat_cache[run_hash] = last_heartbeat + elif last_heartbeat.idx > last_recorded_heartbeat.idx: + # Newer heartbeat arrived, so the run isn't stalled + self.run_heartbeat_cache[run_hash] = last_heartbeat + else: + # No new heartbeat event since last time; check if enough time passed + time_passed = time.time() - last_recorded_heartbeat.detected_epoch_time + if (last_recorded_heartbeat.next_event_in + RunStatusManager.INDEXING_GRACE_PERIOD) < time_passed: + stalled = True + else: + stalled = True + + return stalled + + def _mark_run_as_terminated(self, run_hash: str): + # TODO [AT]: Add run state handling once decided on terms (finished, terminated, aborted, etc.) + try: + meta_run_tree = self.repo.request_tree('meta', run_hash, read_only=False).subtree( + ('meta', 'chunks', run_hash) + ) + if meta_run_tree.get('end_time') is None: + meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() + progress_path = self.progress_dir / run_hash + progress_path.unlink(missing_ok=True) + except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): + self._corrupted_runs.add(run_hash) diff --git a/aim/sdk/sequence.py b/aim/sdk/sequence.py index de8c78e1d1..dde9e215f0 100644 --- a/aim/sdk/sequence.py +++ b/aim/sdk/sequence.py @@ -201,7 +201,7 @@ def numpy(self) -> Tuple[np.ndarray, List[np.ndarray]]: sort_indices = steps.argsort() columns = [arr[sort_indices] for arr in columns] steps = steps[sort_indices] - if last_step is not None and last_step != steps[-1]: + if last_step is not None and last_step > steps[-1]: step_hash = self.step_hash(last_step) # The `last_step` is provided by the meta tree which may potentially # be out of sync with the series tree. 
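Note: the RunStatusManager added above is a self-contained watcher: it periodically lists run hashes under `meta/progress`, compares their newest heartbeat events under `check_ins`, and writes `end_time` for runs that stopped reporting. A minimal sketch of driving it by hand follows; the repository path is a placeholder and not part of the patch.

# Illustrative sketch only: run the stalled-run watcher against a local repo.
import time

from aim import Repo
from aim.sdk.run_status_manager import RunStatusManager

repo = Repo('/path/to/repo')                  # placeholder: directory that contains the .aim store
status_mng = RunStatusManager(repo, scan_interval=60)

status_mng.start()                            # daemon thread scans meta/progress on each interval
try:
    time.sleep(300)                           # let it check heartbeats for a while
finally:
    status_mng.stop()                         # sets the stop event and joins the monitor thread
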
diff --git a/aim/sdk/sequence_collection.py b/aim/sdk/sequence_collection.py index 5738a8e280..62c083d45a 100644 --- a/aim/sdk/sequence_collection.py +++ b/aim/sdk/sequence_collection.py @@ -3,6 +3,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Iterator +from aim.sdk.query_analyzer import QueryExpressionTransformer from aim.sdk.query_utils import RunView, SequenceView from aim.sdk.sequence import Sequence from aim.sdk.types import QueryReportMode @@ -170,20 +171,37 @@ def iter_runs(self) -> Iterator['SequenceCollection']: if self.report_mode == QueryReportMode.PROGRESS_BAR: progress_bar = tqdm(total=total_runs) + seq_var = self.seq_cls.sequence_name() + t = QueryExpressionTransformer( + var_names=[ + seq_var, + ] + ) + run_expr, is_transformed = t.transform(self.query) + run_query = RestrictedPythonQuery(run_expr) + for run in runs_iterator: - seq_collection = SingleRunSequenceCollection( - run, - self.seq_cls, - self.query, - runs_proxy_cache=self.runs_proxy_cache, - timezone_offset=self._timezone_offset, - ) - if self.report_mode == QueryReportMode.PROGRESS_TUPLE: - yield seq_collection, (runs_counter, total_runs) - else: - if self.report_mode == QueryReportMode.PROGRESS_BAR: - progress_bar.update(1) - yield seq_collection + check_run_sequences = True + if is_transformed: + run_view = RunView(run, runs_proxy_cache=self.runs_proxy_cache, timezone_offset=self._timezone_offset) + match = run_query.check(**{'run': run_view}) + if not match: + check_run_sequences = False + + if check_run_sequences: + seq_collection = SingleRunSequenceCollection( + run, + self.seq_cls, + self.query, + runs_proxy_cache=self.runs_proxy_cache, + timezone_offset=self._timezone_offset, + ) + if self.report_mode == QueryReportMode.PROGRESS_TUPLE: + yield seq_collection, (runs_counter, total_runs) + else: + if self.report_mode == QueryReportMode.PROGRESS_BAR: + progress_bar.update(1) + yield seq_collection runs_counter += 1 def iter(self) -> Iterator[Sequence]: diff --git a/aim/sdk/sequences/figure_sequence.py b/aim/sdk/sequences/figure_sequence.py index ff6081e60f..885828f79a 100644 --- a/aim/sdk/sequences/figure_sequence.py +++ b/aim/sdk/sequences/figure_sequence.py @@ -9,7 +9,7 @@ class Figures(Sequence): @classmethod def allowed_dtypes(cls) -> Union[str, Tuple[str, ...]]: - return (Figure.get_typename(),) # noqa : need a tuple for consitancy + return (Figure.get_typename(),) # need a tuple for consitancy @classmethod def sequence_name(cls) -> str: diff --git a/aim/sdk/types.py b/aim/sdk/types.py index 51fdc72cd8..aa70e24ec2 100644 --- a/aim/sdk/types.py +++ b/aim/sdk/types.py @@ -1,6 +1,7 @@ -from aim.storage.types import * # noqa F401 from enum import Enum +from aim.storage.types import * # noqa: F403 + class QueryReportMode(Enum): DISABLED = 0 diff --git a/aim/sdk/uri_service.py b/aim/sdk/uri_service.py index 10d588918e..062c05ac60 100644 --- a/aim/sdk/uri_service.py +++ b/aim/sdk/uri_service.py @@ -55,7 +55,7 @@ def request_batch(self, uri_batch: List[str]) -> Iterator[Dict[str, bytes]]: for uri, sub_name, resource_path in self.runs_pool[run_name]: container = run_containers.get(sub_name) if not container: - container = self._get_container(run_name, sub_name) + container = self.repo.request_container(sub_name, run_name, read_only=True) run_containers[sub_name] = container resource_path = decode_path(bytes.fromhex(resource_path)) @@ -70,11 +70,3 @@ def request_batch(self, uri_batch: List[str]) -> Iterator[Dict[str, bytes]]: # clear runs pool self.runs_pool.clear() - - def 
_get_container(self, run_name: str, sub_name: str): - if sub_name == 'meta': - container = self.repo.request(sub_name, run_name, from_union=True, read_only=True) - else: - container = self.repo.request(sub_name, run_name, read_only=True) - - return container diff --git a/aim/sdk/utils.py b/aim/sdk/utils.py index 6863600f94..0e5ff84faa 100644 --- a/aim/sdk/utils.py +++ b/aim/sdk/utils.py @@ -165,13 +165,12 @@ def flatten(d, parent_path=None): return all_paths subtrees_to_lookup = ('attrs', 'traces_types', 'contexts', 'traces') - repo_meta_tree = repo._get_meta_tree() # set of all repo paths that can be left dangling after run deletion repo_paths = set() for key in subtrees_to_lookup: try: - repo_paths.update(flatten(repo_meta_tree.collect(key, strict=False), parent_path=(key,))) + repo_paths.update(flatten(repo.meta_tree.collect(key, strict=False), parent_path=(key,))) except KeyError: pass @@ -179,7 +178,7 @@ def flatten(d, parent_path=None): for run_hash in tqdm(run_hashes): # construct unique paths set for each run run_paths = set() - run_meta_tree = repo.request_tree('meta', run_hash, from_union=False, read_only=True).subtree('meta') + run_meta_tree = repo.request_tree('meta', run_hash, read_only=True).subtree('meta') for key in subtrees_to_lookup: try: run_paths.update(flatten(run_meta_tree.collect(key, strict=False), parent_path=(key,))) diff --git a/aim/storage/artifacts/s3_storage.py b/aim/storage/artifacts/s3_storage.py index bc24c7372d..d30951bb18 100644 --- a/aim/storage/artifacts/s3_storage.py +++ b/aim/storage/artifacts/s3_storage.py @@ -73,7 +73,7 @@ def _get_s3_client(self): return client -def S3ArtifactStorage_factory(**boto3_client_kwargs: dict): +def S3ArtifactStorage_factory(**boto3_client_kwargs): class S3ArtifactStorageCustom(S3ArtifactStorage): def _get_s3_client(self): import boto3 @@ -88,7 +88,7 @@ def _get_s3_client(self): return S3ArtifactStorageCustom -def S3ArtifactStorage_clientconfig(**boto3_client_kwargs: dict): +def S3ArtifactStorage_clientconfig(**boto3_client_kwargs): from aim.storage.artifacts import registry registry.registry['s3'] = S3ArtifactStorage_factory(**boto3_client_kwargs) diff --git a/aim/storage/encoding/encoding.pyx b/aim/storage/encoding/encoding.pyx index 71f2ca40bd..308627e01d 100644 --- a/aim/storage/encoding/encoding.pyx +++ b/aim/storage/encoding/encoding.pyx @@ -22,7 +22,7 @@ from aim.storage.encoding.encoding_native cimport ( decode_double, decode_utf_8_str, ) -from aim.storage.encoding.encoding_native cimport decode_path # noqa F401 +from aim.storage.encoding.encoding_native cimport decode_path # noqa: F401 from aim.storage.utils import ArrayFlagType, ObjectFlagType, CustomObjectFlagType from aim.storage.utils import ArrayFlag, ObjectFlag from aim.storage.container import ContainerValue diff --git a/aim/storage/hashing/hashing.py b/aim/storage/hashing/hashing.py index 1aaa7e52e5..eef53c5d2c 100644 --- a/aim/storage/hashing/hashing.py +++ b/aim/storage/hashing/hashing.py @@ -11,7 +11,7 @@ from typing import Tuple, Union -from aim.storage.encoding import decode_int64, encode_int64 # noqa +from aim.storage.encoding import decode_int64, encode_int64 from aim.storage.hashing import c_hash from aim.storage.types import ( AimObject, diff --git a/aim/storage/migrations/versions/661514b12ee1_.py b/aim/storage/migrations/versions/661514b12ee1_.py new file mode 100644 index 0000000000..eacfccfe85 --- /dev/null +++ b/aim/storage/migrations/versions/661514b12ee1_.py @@ -0,0 +1,69 @@ +"""empty message + +Revision ID: 661514b12ee1 +Revises: 
46b89d830ad8 +Create Date: 2025-06-05 19:52:31.221392 + +""" +from alembic import op +import sqlalchemy as sa +from alembic.context import get_context + + +# revision identifiers, used by Alembic. +revision = '661514b12ee1' +down_revision = '46b89d830ad8' +branch_labels = None +depends_on = None + + + +def upgrade(): + # Get the SQLite connection context + context = get_context() + naming_convention = { + "fk": + "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + } + # Use batch operations for SQLite + with op.batch_alter_table('run_tag', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_run_tag_run_id_run', type_='foreignkey') + batch_op.drop_constraint('fk_run_tag_tag_id_tag', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_run_tag_run_id_run', 'run', ['run_id'], ['id'], ondelete='CASCADE') + batch_op.create_foreign_key('fk_run_tag_tag_id_tag', 'tag', ['tag_id'], ['id'], ondelete='CASCADE') + + + with op.batch_alter_table('note', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_note_run_id_run', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_note_run_id_run', 'run', ['run_id'], ['id'], ondelete='CASCADE') + + +def downgrade(): + # Use batch operations for SQLite + naming_convention = { + "fk": + "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + } + # Use batch operations for SQLite + with op.batch_alter_table('run_tag', naming_convention=naming_convention) as batch_op: + # Drop the CASCADE foreign key + batch_op.drop_constraint('fk_run_tag_run_id_run', type_='foreignkey') + batch_op.drop_constraint('fk_run_tag_tag_id_tag', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_run_tag_run_id_run', 'run', ['run_id'], ['id'],) + batch_op.create_foreign_key('fk_run_tag_tag_id_tag', 'tag', ['tag_id'], ['id'],) + + with op.batch_alter_table('note', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_note_run_id_run', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_note_run_id_run', 'run', ['run_id'], ['id'],) + diff --git a/aim/storage/proxy.py b/aim/storage/proxy.py index d9c62c8c2e..cf3d84fca5 100644 --- a/aim/storage/proxy.py +++ b/aim/storage/proxy.py @@ -174,7 +174,7 @@ def __call__(self): if self.cache is not None and cache_key is not None: self.cache[cache_key] = res - return res + return res class AimObjectProxy(with_metaclass(_ObjectProxyMetaType)): @@ -192,8 +192,8 @@ def __name__(self, value): def __class__(self): return self.__wrapped__().__class__ - @__class__.setter # noqa - def __class__(self, value): # noqa + @__class__.setter + def __class__(self, value): self.__wrapped__().__class__ = value @property diff --git a/aim/storage/query.py b/aim/storage/query.py index 0ada6f1534..82de23657a 100644 --- a/aim/storage/query.py +++ b/aim/storage/query.py @@ -52,7 +52,7 @@ def safer_getattr(object, name, default=None, getattr=getattr): if name == 'format' and isinstance(object, str): raise NotImplementedError('Using format() on a %s is not safe.' 
% object.__class__.__name__) if name[0] == '_': - raise AttributeError('"{name}" is an invalid attribute name because it ' 'starts with "_"'.format(name=name)) + raise AttributeError('"{name}" is an invalid attribute name because it starts with "_"'.format(name=name)) val = getattr(object, name, default) return val diff --git a/aim/storage/rockscontainer.pyx b/aim/storage/rockscontainer.pyx index 1be6f9086a..e96fc4b42f 100644 --- a/aim/storage/rockscontainer.pyx +++ b/aim/storage/rockscontainer.pyx @@ -35,6 +35,7 @@ class RocksAutoClean(AutoClean): super().__init__(instance) self._lock = None self._db = None + self._progress_path = None def _close(self): """ @@ -48,6 +49,9 @@ class RocksAutoClean(AutoClean): self._db = None self._lock.release() self._lock = None + if self._progress_path is not None: + self._progress_path.unlink(missing_ok=True) + self._progress_path = None if self._db is not None: self._db = None @@ -104,6 +108,7 @@ class RocksContainer(Container): if not self.read_only: progress_dir.mkdir(parents=True, exist_ok=True) self._progress_path.touch(exist_ok=True) + self._resources._progress_path = self._progress_path self.db # TODO check if Containers are reopenable @@ -144,7 +149,7 @@ class RocksContainer(Container): lock_cls = self.get_lock_cls() self._lock = lock_cls(self._lock_path, timeout) self._lock.acquire() - else: + elif not self._extra_opts.get('skip_read_optimization', False): self.optimize_for_read() self._db = aimrocks.DB(str(self.path), @@ -159,16 +164,9 @@ class RocksContainer(Container): Store the collection of `(key, value)` records in the :obj:`Container` `index` for fast reads. """ - if not self._progress_path: - return - for k, v in self.items(): index[k] = v - if self._progress_path.exists(): - self._progress_path.unlink() - self._progress_path = None - def close(self): """Close all the resources.""" if self._resources is None: diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index f8acff5bb5..a58b9a9eca 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -9,7 +9,7 @@ ) from aim.storage.types import SafeNone from aim.web.configs import AIM_LOG_LEVEL_KEY -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event from sqlalchemy.orm import scoped_session, sessionmaker import aim.storage.drop_table_cascade # noqa: F401 @@ -66,7 +66,10 @@ def __init__(self, path: str, readonly: bool = False): self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), pool_pre_ping=True + # pool_size=10, + # max_overflow=20, ) + event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on')) self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) self._upgraded = None diff --git a/aim/storage/structured/sql_engine/entities.py b/aim/storage/structured/sql_engine/entities.py index 84c72158cb..21d08626f6 100644 --- a/aim/storage/structured/sql_engine/entities.py +++ b/aim/storage/structured/sql_engine/entities.py @@ -1,3 +1,5 @@ +import logging + from typing import Collection, List, Optional, Union import pytz @@ -26,6 +28,9 @@ from sqlalchemy.orm import joinedload +logger = logging.getLogger(__name__) + + def session_commit_or_flush(session): if getattr(session, 'autocommit', True) and sa_version >= '2.0.0': session.commit() @@ -82,11 +87,9 @@ def from_hash(cls, runhash: str, created_at, session) -> 'ModelMappedRun': @classmethod def delete_run(cls, runhash: str, session) -> bool: - try: - rows_affected = 
session.query(RunModel).filter(RunModel.hash == runhash).delete() - session_commit_or_flush(session) - except Exception: - return False + rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete() + session_commit_or_flush(session) + return rows_affected > 0 @classmethod @@ -227,6 +230,10 @@ def unsafe_add_tag(): self._model.tags.append(tag) session.add(self._model) + if value in self.tags: + logger.warning(f'Tag with value: {value} is already present in this run.') + return + session = self._session unsafe_add_tag() try: diff --git a/aim/storage/structured/sql_engine/models.py b/aim/storage/structured/sql_engine/models.py index 1c78c539e7..9859a8d853 100644 --- a/aim/storage/structured/sql_engine/models.py +++ b/aim/storage/structured/sql_engine/models.py @@ -28,7 +28,7 @@ def default_to_run_hash(context): run_tags = Table( 'run_tag', Base.metadata, - Column('run_id', Integer, ForeignKey('run.id'), primary_key=True, nullable=False), + Column('run_id', Integer, ForeignKey('run.id', ondelete='CASCADE'), primary_key=True, nullable=False), Column('tag_id', Integer, ForeignKey('tag.id'), primary_key=True, nullable=False), ) @@ -51,7 +51,9 @@ class Run(Base): experiment_id = Column(ForeignKey('experiment.id'), nullable=True) experiment = relationship('Experiment', backref=backref('runs', uselist=True, order_by='Run.created_at.desc()')) - tags = relationship('Tag', secondary=run_tags, backref=backref('runs', uselist=True)) + tags = relationship( + 'Tag', secondary=run_tags, backref=backref('runs', uselist=True), cascade='all, delete', passive_deletes=True + ) notes = relationship('Note', back_populates='run') def __init__(self, run_hash, created_at=None): @@ -106,7 +108,7 @@ class Note(Base): id = Column(Integer, autoincrement=True, primary_key=True) content = Column(Text, nullable=False, default='') - run_id = Column(Integer, ForeignKey('run.id')) + run_id = Column(Integer, ForeignKey('run.id', ondelete='CASCADE'),) experiment_id = Column(Integer, ForeignKey('experiment.id')) created_at = Column(DateTime, default=datetime.datetime.utcnow) diff --git a/aim/storage/treeviewproxy.py b/aim/storage/treeviewproxy.py index f459a096a7..d2e188e84f 100644 --- a/aim/storage/treeviewproxy.py +++ b/aim/storage/treeviewproxy.py @@ -24,8 +24,6 @@ def __init__( sub: str, *, read_only: bool, - from_union: bool = False, - no_cache: bool = False, index=False, timeout=None, ): @@ -38,10 +36,8 @@ def __init__( 'name': name, 'sub': sub, 'read_only': read_only, - 'from_union': from_union, 'index': index, 'timeout': timeout, - 'no_cache': no_cache, } self.init_args = pack_args(encode_tree(kwargs)) self.resource_type = 'TreeView' diff --git a/aim/storage/types.py b/aim/storage/types.py index a21caa0611..6fbf6e0129 100644 --- a/aim/storage/types.py +++ b/aim/storage/types.py @@ -1,6 +1,8 @@ -from aim.storage.utils import BLOB # noqa F401 from typing import Dict, List, Tuple, Union +from aim.storage.utils import BLOB as BLOB + + NoneType = type(None) diff --git a/aim/storage/union.pyx b/aim/storage/union.pyx index 2d5729c753..e9bafc5773 100644 --- a/aim/storage/union.pyx +++ b/aim/storage/union.pyx @@ -242,11 +242,8 @@ class DB(object): index_db = None logger.info('No index was detected') - # If index exists -- only load those in progress - selector = 'progress' if index_db is not None else 'chunks' - new_dbs: Dict[bytes, aimrocks.DB] = {} - db_dir = os.path.join(self.db_path, self.db_name, selector) + db_dir = os.path.join(self.db_path, self.db_name, 'chunks') for prefix in self._list_dir(db_dir): 
path = os.path.join(self.db_path, self.db_name, "chunks", prefix) prefix = encode_path((self.db_name, "chunks", prefix)) diff --git a/aim/tensorflow.py b/aim/tensorflow.py index 93ccaed16a..17cee65496 100644 --- a/aim/tensorflow.py +++ b/aim/tensorflow.py @@ -1,2 +1,3 @@ # Alias to SDK TensorFlow Keras interface -from aim.sdk.adapters.tensorflow import AimCallback, AimTracker # noqa F401 +from aim.sdk.adapters.tensorflow import AimCallback as AimCallback +from aim.sdk.adapters.tensorflow import AimTracker as AimTracker diff --git a/aim/utils/__init__.py b/aim/utils/__init__.py index 761d0cd34d..c48598750b 100644 --- a/aim/utils/__init__.py +++ b/aim/utils/__init__.py @@ -1 +1,2 @@ -from aim.ext.exception_resistant import enable_safe_mode, disable_safe_mode # noqa +from aim.ext.exception_resistant import disable_safe_mode as disable_safe_mode +from aim.ext.exception_resistant import enable_safe_mode as enable_safe_mode diff --git a/aim/utils/deprecation.py b/aim/utils/deprecation.py index 46bd86266a..b0bf2a70fd 100644 --- a/aim/utils/deprecation.py +++ b/aim/utils/deprecation.py @@ -10,11 +10,11 @@ def python_version_deprecation_check(): import sys version_info = sys.version_info - if version_info.major == 3 and version_info.minor == 6: + if version_info.major == 3 and version_info.minor == 7: deprecation_warning( - remove_version='3.16', - msg='Python 3.6 has reached EOL. Aim support for Python 3.6 is deprecated!', - remove_msg_template='Python 3.6 support will be dropped in', + remove_version='3.30', + msg='Python 3.7 has reached EOL. Aim support for Python 3.7 is deprecated!', + remove_msg_template='Python 3.7 support will be dropped in', ) diff --git a/aim/web/api/__init__.py b/aim/web/api/__init__.py index 52b5095b06..cb553a1e44 100644 --- a/aim/web/api/__init__.py +++ b/aim/web/api/__init__.py @@ -23,7 +23,6 @@ def create_app(): max_age=86400, ) - from aim.sdk.index_manager import RepoIndexManager from aim.web.api.dashboard_apps.views import dashboard_apps_router from aim.web.api.dashboards.views import dashboards_router from aim.web.api.experiments.views import experiment_router @@ -36,11 +35,6 @@ def create_app(): from aim.web.api.views import statics_router from aim.web.configs import AIM_UI_BASE_PATH - # The indexing thread has to run in the same process as the uvicorn app itself. - # This allows sharing state of indexing using memory instead of process synchronization methods. 
- index_mng = RepoIndexManager.get_index_manager(Project().repo) - index_mng.start_indexing_thread() - api_app = FastAPI() api_app.add_middleware(GZipMiddleware, compresslevel=1) api_app.add_middleware(ResourceCleanupMiddleware) diff --git a/aim/web/api/dashboard_apps/views.py b/aim/web/api/dashboard_apps/views.py index 50fb871238..7f1acae268 100644 --- a/aim/web/api/dashboard_apps/views.py +++ b/aim/web/api/dashboard_apps/views.py @@ -19,7 +19,7 @@ @dashboard_apps_router.get('/', response_model=ExploreStateListOut) async def dashboard_apps_list_api(session: Session = Depends(get_session)): - explore_states = session.query(ExploreState).filter(ExploreState.is_archived == False) # noqa + explore_states = session.query(ExploreState).filter(ExploreState.is_archived == False) # noqa: E712 result = [] for es in explore_states: result.append(explore_state_response_serializer(es)) diff --git a/aim/web/api/dashboards/views.py b/aim/web/api/dashboards/views.py index fe12bf69d1..5cd16c3020 100644 --- a/aim/web/api/dashboards/views.py +++ b/aim/web/api/dashboards/views.py @@ -19,7 +19,7 @@ @dashboards_router.get('/', response_model=List[DashboardOut]) async def dashboards_list_api(session: Session = Depends(get_session)): - dashboards_query = session.query(Dashboard).filter(Dashboard.is_archived == False).order_by(Dashboard.updated_at) # noqa + dashboards_query = session.query(Dashboard).filter(Dashboard.is_archived == False).order_by(Dashboard.updated_at) # noqa: E712 result = [] for dashboard in dashboards_query: @@ -50,7 +50,7 @@ async def dashboards_post_api(request_data: DashboardCreateIn, session: Session @dashboards_router.get('/{dashboard_id}/', response_model=DashboardOut) async def dashboards_get_api(dashboard_id: str, session: Session = Depends(get_session)): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) @@ -61,7 +61,7 @@ async def dashboards_get_api(dashboard_id: str, session: Session = Depends(get_s async def dashboards_put_api( dashboard_id: str, request_data: DashboardUpdateIn, session: Session = Depends(get_session) ): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) dashboard_name = request_data.name @@ -77,7 +77,7 @@ async def dashboards_put_api( @dashboards_router.delete('/{dashboard_id}/') async def dashboards_delete_api(dashboard_id: str, session: Session = Depends(get_session)): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 182ad20919..562c840ca3 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -14,7 +14,11 @@ engine = create_engine( get_db_url(), echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), + # connect_args={'check_same_thread': False}, + # pool_size=10, + # max_overflow=20, ) + SessionLocal = 
sessionmaker(autoflush=False, bind=engine) Base = declarative_base() diff --git a/aim/web/api/experiments/views.py b/aim/web/api/experiments/views.py index 9164eb78f4..a47511cd9e 100644 --- a/aim/web/api/experiments/views.py +++ b/aim/web/api/experiments/views.py @@ -114,7 +114,7 @@ async def update_experiment_properties_api(exp_id: str, exp_in: ExperimentUpdate if exp_in.archived is not None: if exp_in.archived and len(exp.runs) > 0: raise HTTPException( - status_code=400, detail=(f"Cannot archive experiment '{exp_id}'. " 'Experiment has associated runs.') + status_code=400, detail=(f"Cannot archive experiment '{exp_id}'. Experiment has associated runs.") ) exp.archived = exp_in.archived diff --git a/aim/web/api/projects/project.py b/aim/web/api/projects/project.py index 2fa29ee7a7..b1ae57eba2 100644 --- a/aim/web/api/projects/project.py +++ b/aim/web/api/projects/project.py @@ -18,7 +18,6 @@ def __init__(self): def cleanup_repo_pools(self): self.repo.container_pool.clear() - self.repo.container_view_pool.clear() self.repo.persistent_pool.clear() def cleanup_sql_caches(self): diff --git a/aim/web/api/projects/views.py b/aim/web/api/projects/views.py index c32cc5452a..36856ba27a 100644 --- a/aim/web/api/projects/views.py +++ b/aim/web/api/projects/views.py @@ -5,7 +5,6 @@ from logging import getLogger from typing import Optional, Tuple -from aim.sdk.index_manager import RepoIndexManager from aim.storage.locking import AutoFileLock from aim.web.api.projects.project import Project from aim.web.api.projects.pydantic_models import ( @@ -171,13 +170,3 @@ async def project_params_api(sequence: Optional[Tuple[str, ...]] = Query(()), ex } response.update(**project.repo.collect_sequence_info(sequence)) return response - - -@projects_router.get('/status/') -async def project_status_api(): - project = Project() - - if not project.exists(): - raise HTTPException(status_code=404) - - return RepoIndexManager.get_index_manager(project.repo).repo_status diff --git a/aim/web/api/runs/object_views.py b/aim/web/api/runs/object_views.py index 8e5d5b3f5a..6a4b4d0e93 100644 --- a/aim/web/api/runs/object_views.py +++ b/aim/web/api/runs/object_views.py @@ -32,7 +32,7 @@ class CustomObjectApiConfig: sequence_type: type = Sequence resolve_blobs: bool = False - dump_record_fn: callable = lambda x: x.data # noqa E731 + dump_record_fn: callable = lambda x: x.data model: type = BaseModel @staticmethod @@ -165,7 +165,7 @@ class TextApiConfig(CustomObjectApiConfig): class DistributionApiConfig(CustomObjectApiConfig): sequence_type = Distributions resolve_blobs = True - dump_record_fn = lambda x: numpy_to_encodable(x.weights) # noqa E731 + dump_record_fn = lambda x: numpy_to_encodable(x.weights) # noqa: E731 model = DistributionInfo @@ -178,5 +178,5 @@ class AudioApiConfig(CustomObjectApiConfig): class FigureApiConfig(CustomObjectApiConfig): sequence_type = Figures resolve_blobs = True - dump_record_fn = lambda x: x.data # noqa E731 + dump_record_fn = lambda x: x.data # noqa: E731 model = FigureInfo diff --git a/aim/web/api/tags/views.py b/aim/web/api/tags/views.py index 7b4bfbbeca..ce0d32e998 100644 --- a/aim/web/api/tags/views.py +++ b/aim/web/api/tags/views.py @@ -41,7 +41,8 @@ async def search_tags_by_name_api(q: Optional[str] = '', factory=Depends(object_ 'id': tag.uuid, 'name': tag.name, 'color': tag.color, - 'description' 'run_count': len(tag.runs), + 'description': tag.description, + 'run_count': len(tag.runs), 'archived': tag.archived, } for tag in factory.search_tags(q.strip()) diff --git 
a/aim/web/middlewares/profiler/profiler.py b/aim/web/middlewares/profiler/profiler.py index 0956dc8b56..fc39f40061 100644 --- a/aim/web/middlewares/profiler/profiler.py +++ b/aim/web/middlewares/profiler/profiler.py @@ -61,7 +61,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: profiler = self.profiler(interval=self._profiler_interval) try: profiler.start() - except: # noqa + except: # noqa: E722 skip_profiling = True else: skip_profiling = False diff --git a/aim/web/ui/package.json b/aim/web/ui/package.json index 567b5732ab..99ebb2bb8e 100644 --- a/aim/web/ui/package.json +++ b/aim/web/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui_v2", - "version": "3.27.0", + "version": "3.29.1", "private": true, "dependencies": { "@aksel/structjs": "^1.0.0", diff --git a/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx b/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx index 1cd4bb89d1..24b28e31f3 100644 --- a/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx +++ b/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx @@ -8,8 +8,6 @@ import { ANALYTICS_EVENT_KEYS } from 'config/analytics/analyticsKeysMap'; import * as analytics from 'services/analytics'; -import useRunMetricsBatch from '../hooks/useRunMetricsBatch'; - import GitInfoCard from './components/GitInfoCard'; import RunOverviewTabMetricsCard from './components/MetricsCard/RunOverviewTabMetricsCard'; import RunOverviewTabArtifactsCard from './components/ArtifactsCard/RunOverviewTabArtifactsCard'; @@ -28,11 +26,6 @@ function RunOverviewTab({ runData, runHash }: IRunOverviewTabProps) { const overviewSectionContentRef = React.useRef(null); const [containerHeight, setContainerHeight] = React.useState(0); - useRunMetricsBatch({ - runTraces: runData.runTraces, - runHash, - }); - React.useEffect(() => { analytics.pageView( ANALYTICS_EVENT_KEYS.runDetails.tabs['overview'].tabView, diff --git a/aim/web/ui/src/utils/aggregateGroupData.ts b/aim/web/ui/src/utils/aggregateGroupData.ts index 0a0acc69e4..2fa05644af 100644 --- a/aim/web/ui/src/utils/aggregateGroupData.ts +++ b/aim/web/ui/src/utils/aggregateGroupData.ts @@ -113,6 +113,17 @@ export function aggregateGroupData({ } } } + // add special case handling for single point metrics + if (trace.xValues.length === 1) { + const step = trace.xValues[0]; + let value = chartXValues.indexOf(step); + let y = trace.yValues[0]; + if (yValuesPerX.hasOwnProperty(value)) { + yValuesPerX[value].push(y); + } else { + yValuesPerX[value] = [y]; + } + } } } diff --git a/aim/web/ui/src/utils/app/alignMetricData.ts b/aim/web/ui/src/utils/app/alignMetricData.ts index 89129a173e..ab623fef31 100644 --- a/aim/web/ui/src/utils/app/alignMetricData.ts +++ b/aim/web/ui/src/utils/app/alignMetricData.ts @@ -44,21 +44,33 @@ export function alignByEpoch( epochs[epoch] = [metric.data.steps[i]]; } }); - + // Get unique epoch values (for ex. (1, 1.495) instead of (1, 1)), because the epochs can be duplicate + let xValues = [ + ...metric.data.epochs.map((epoch, i) => { + return ( + epoch + + (epochs[epoch].length > 1 + ? 
(0.99 / epochs[epoch].length) * + epochs[epoch].indexOf(metric.data.steps[i]) + : 0) + ); + }), + ]; + let yValues = [...metric.data.values]; + let pointsArray = []; + // Combine the x and y axis arrays into an array of points + for (let idx = 0; idx < xValues.length; idx++) { + pointsArray[idx] = [xValues[idx], yValues[idx]]; + } + // Sort the combined array based on the first element of the point (epoch) + pointsArray.sort(function (a, b) { + return a[0] - b[0]; + }); metric.data = { ...metric.data, - xValues: [ - ...metric.data.epochs.map((epoch, i) => { - return ( - epoch + - (epochs[epoch].length > 1 - ? (0.99 / epochs[epoch].length) * - epochs[epoch].indexOf(metric.data.steps[i]) - : 0) - ); - }), - ], - yValues: [...metric.data.values], + // Separate the x and y axis values back into xValues and yValues + xValues: pointsArray.map((point) => point[0]), + yValues: pointsArray.map((point) => point[1]), }; } } diff --git a/docs/source/ui/pages/bookmarks.md b/docs/source/ui/pages/bookmarks.md index 4d5e080b2e..900594023f 100644 --- a/docs/source/ui/pages/bookmarks.md +++ b/docs/source/ui/pages/bookmarks.md @@ -2,7 +2,7 @@ ### Overview -Use the Bookmarks to save the Aim Explorer state. This includes search query, aggregations and any other modifications applied to the explorer. The Bookmarks page is a list of [cards](#bookmark-card) to quickly access the explorer state with one click. +Use the Bookmarks to save the Aim Explorer state. This includes search query, aggregations and any other modifications applied to the explorer. The Bookmarks page is a list of [cards](#the-bookmark-card) to quickly access the explorer state with one click. bookmarks

 

diff --git a/examples/pytorch_lightning_track.py b/examples/pytorch_lightning_track.py index 9d9d8c98b5..3a4b23c7ce 100644 --- a/examples/pytorch_lightning_track.py +++ b/examples/pytorch_lightning_track.py @@ -4,7 +4,7 @@ if importlib.util.find_spec('lightning'): import lightning.pytorch as pl -elif importlib.util.find_spec('pytorch_lightning'): # noqa F401 +elif importlib.util.find_spec('pytorch_lightning'): # F401 import pytorch_lightning as pl else: raise RuntimeError( diff --git a/examples/pytorch_track.py b/examples/pytorch_track.py index 3927356cb1..3c68bd51d2 100644 --- a/examples/pytorch_track.py +++ b/examples/pytorch_track.py @@ -90,7 +90,7 @@ def forward(self, x): if i % 30 == 0: logging.info( - 'Epoch [{}/{}], Step [{}/{}], ' 'Loss: {:.4f}'.format( + 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format( epoch + 1, num_epochs, i + 1, total_step, loss.item() ) ) diff --git a/examples/pytorch_track_images.py b/examples/pytorch_track_images.py index adb693a2f3..bcf4f627d9 100644 --- a/examples/pytorch_track_images.py +++ b/examples/pytorch_track_images.py @@ -94,7 +94,7 @@ def forward(self, x): if i % 30 == 0: logging.info( - 'Epoch [{}/{}], Step [{}/{}], ' 'Loss: {:.4f}'.format( + 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format( epoch + 1, num_epochs, i + 1, total_step, loss.item() ) ) diff --git a/performance_tests/BASELINE b/performance_tests/BASELINE index 1f403fadb1..ebe22e9386 100644 --- a/performance_tests/BASELINE +++ b/performance_tests/BASELINE @@ -1,21 +1,21 @@ -test_collect_metrics_data_0 1.8545958518981933 -test_collect_metrics_data_1 1.9959398269653321 -test_collect_metrics_data_2 10.835494375228881 -test_collect_metrics_data_3 1.8672633171081543 -test_collect_runs_data_0 0.8988437175750732 -test_collect_runs_data_1 1.039186429977417 -test_collect_runs_data_2 3.469265604019165 -test_collect_runs_data_3 0.9486905574798584 -test_query_metrics_0 1.6766140460968018 -test_query_metrics_1 1.6763684749603271 -test_query_metrics_2 1.6051365375518798 -test_query_metrics_3 1.5391615390777589 -test_query_runs_0 0.8991998672485352 -test_query_runs_1 0.9259328842163086 -test_query_runs_2 0.839762544631958 -test_query_runs_3 0.832861852645874 -test_container_open 0.1440361499786377 -test_iterative_access 4.000607919692993 -test_random_access_0 0.663770055770874 -test_random_access_1 1.4745195388793946 -test_random_access_2 2.424658107757568 \ No newline at end of file +test_collect_metrics_data_0 0.3717397689819336 +test_collect_metrics_data_1 0.3963047981262207 +test_collect_metrics_data_2 2.7405614376068117 +test_collect_metrics_data_3 0.3710219860076904 +test_collect_runs_data_0 0.17322354316711425 +test_collect_runs_data_1 0.20246338844299316 +test_collect_runs_data_2 0.7970072269439697 +test_collect_runs_data_3 0.1911233901977539 +test_query_metrics_0 0.311903190612793 +test_query_metrics_1 0.3122593879699707 +test_query_metrics_2 0.3092495441436768 +test_query_metrics_3 0.288785982131958 +test_query_runs_0 0.17433061599731445 +test_query_runs_1 0.17484822273254394 +test_query_runs_2 0.17181901931762694 +test_query_runs_3 0.1616499423980713 +test_container_open 0.04026708602905273 +test_iterative_access 1.1857992172241212 +test_random_access_0 0.14068403244018554 +test_random_access_1 0.26419754028320314 +test_random_access_2 0.3941319942474365 \ No newline at end of file diff --git a/performance_tests/conftest.py b/performance_tests/conftest.py index a53afa2e78..95fa82005c 100644 --- a/performance_tests/conftest.py +++ b/performance_tests/conftest.py @@ -43,7 +43,7 @@ def 
pytest_sessionstart(session): _init_test_repos() else: # github actions performance tests on self hosted runner - os.chdir('/home/ubuntu/performance_logs/') + os.chdir('/Users/github/workers/perf-tests/actions-runner/_work/performance_logs') time.sleep(10) diff --git a/performance_tests/sdk/queries.py b/performance_tests/sdk/queries.py index 514ce55831..38b79bf3d2 100644 --- a/performance_tests/sdk/queries.py +++ b/performance_tests/sdk/queries.py @@ -1,7 +1,7 @@ -query_0 = 'run.hparams.benchmark == "glue" ' 'and run.hparams.dataset == "cola" ' 'and metric.context.subset != "train"' -query_1 = 'run.hparams.benchmark == "glue" ' 'and run.hparams.dataset == "cola"' +query_0 = 'run.hparams.benchmark == "glue" and run.hparams.dataset == "cola" and metric.context.subset != "train"' +query_1 = 'run.hparams.benchmark == "glue" and run.hparams.dataset == "cola"' query_2 = 'run.hparams.benchmark == "glue"' -query_3 = 'run.hparams.dataset == "cola" ' 'and run.experiment.name != "baseline-warp_4-cola"' +query_3 = 'run.hparams.dataset == "cola" and run.experiment.name != "baseline-warp_4-cola"' queries = { diff --git a/requirements.dev.txt b/requirements.dev.txt index 936433a7b6..b36a02f595 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,3 +1,3 @@ wheel >= 0.31.0 twine >= 1.11.0 -ruff == 0.3.3 +ruff == 0.9.2 diff --git a/ruff.toml b/ruff.toml index d164ced829..e23481fd09 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,7 +11,11 @@ exclude = [ inline-quotes = "single" [lint] -extend-select = ["I"] +extend-select = [ + "I", + "PGH004", # blanket-noqa + "RUF100", # unused-noqa +] [lint.isort] no-lines-before = ["future", "standard-library", "first-party"] diff --git a/setup.py b/setup.py index 1cfacd0ca1..983b380168 100644 --- a/setup.py +++ b/setup.py @@ -76,10 +76,14 @@ def package_files(directory): 'packaging>=15.0', 'python-dateutil', 'requests', + 'watchdog', 'websockets', 'boto3', ] +if sys.version_info.minor < 9: + REQUIRED += ['astunparse'] + class UploadCommand(Command): """Support setup.py upload.""" @@ -194,6 +198,7 @@ def cytonize_extensions(): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: Implementation :: PyPy', ], ext_modules=cytonize_extensions(), diff --git a/tests/requirements.txt b/tests/requirements.txt index e3135d0bf3..f2e592cb99 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,6 +2,7 @@ torch tensorflow deeplake<4.0.0 # update when proper documentation is available +azure-storage-blob # for deeplake # hub fastapi>=0.87.0 httpx