From 3cfb5a64c6c699c08b9a44fdea6b8ae8eb07cd69 Mon Sep 17 00:00:00 2001 From: Sidney Batchelder <44208509+sbatchelder@users.noreply.github.com> Date: Wed, 15 Jan 2025 07:04:08 -0600 Subject: [PATCH 01/30] [fix] Bad typing for S3ArtifactStorage_clientconfig args (#3276) --- aim/storage/artifacts/s3_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aim/storage/artifacts/s3_storage.py b/aim/storage/artifacts/s3_storage.py index bc24c7372..d30951bb1 100644 --- a/aim/storage/artifacts/s3_storage.py +++ b/aim/storage/artifacts/s3_storage.py @@ -73,7 +73,7 @@ def _get_s3_client(self): return client -def S3ArtifactStorage_factory(**boto3_client_kwargs: dict): +def S3ArtifactStorage_factory(**boto3_client_kwargs): class S3ArtifactStorageCustom(S3ArtifactStorage): def _get_s3_client(self): import boto3 @@ -88,7 +88,7 @@ def _get_s3_client(self): return S3ArtifactStorageCustom -def S3ArtifactStorage_clientconfig(**boto3_client_kwargs: dict): +def S3ArtifactStorage_clientconfig(**boto3_client_kwargs): from aim.storage.artifacts import registry registry.registry['s3'] = S3ArtifactStorage_factory(**boto3_client_kwargs) From 4321b075b02eaf620ac4e97efa47e0adbf9cfdd6 Mon Sep 17 00:00:00 2001 From: Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:07:09 +0700 Subject: [PATCH 02/30] [docs] Fix pages/bookmarks section links (#3274) --- docs/source/ui/pages/bookmarks.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/ui/pages/bookmarks.md b/docs/source/ui/pages/bookmarks.md index 4d5e080b2..900594023 100644 --- a/docs/source/ui/pages/bookmarks.md +++ b/docs/source/ui/pages/bookmarks.md @@ -2,7 +2,7 @@ ### Overview -Use the Bookmarks to save the Aim Explorer state. This includes search query, aggregations and any other modifications applied to the explorer. The Bookmarks page is a list of [cards](#bookmark-card) to quickly access the explorer state with one click. +Use the Bookmarks to save the Aim Explorer state. This includes search query, aggregations and any other modifications applied to the explorer. The Bookmarks page is a list of [cards](#the-bookmark-card) to quickly access the explorer state with one click. bookmarks
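The first patch above loosens the `**boto3_client_kwargs` annotation so that keyword arguments are forwarded to the boto3 client as intended. A minimal sketch of how that entry point might be exercised follows; it is not taken from the patch itself — the endpoint, bucket and credentials are illustrative placeholders, and the artifact-logging calls assume the `set_artifacts_uri`/`log_artifact` API of recent Aim releases:

    # Register custom boto3 client options for the 's3' artifact scheme,
    # then log an artifact from a run (sketch; values are placeholders).
    from aim import Run
    from aim.storage.artifacts.s3_storage import S3ArtifactStorage_clientconfig

    # After this call, artifact URIs starting with 's3://' are handled by a
    # storage class whose boto3 client is created with these keyword arguments,
    # e.g. an S3-compatible MinIO endpoint.
    S3ArtifactStorage_clientconfig(
        endpoint_url='http://localhost:9000',    # assumption: custom endpoint
        aws_access_key_id='ACCESS_KEY',          # assumption: placeholder credentials
        aws_secret_access_key='SECRET_KEY',
    )

    run = Run()
    run.set_artifacts_uri('s3://my-bucket/aim-artifacts/')   # assumption: placeholder bucket/prefix
    run.log_artifact('checkpoints/model.onnx', name='model.onnx')
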
From 7a98238c890be210f8be8009895d5589dd488ca8 Mon Sep 17 00:00:00 2001 From: Fabian Keller Date: Mon, 20 Jan 2025 08:17:42 +0100 Subject: [PATCH 03/30] [feat] Add `py.typed` marker to allow users to benefit from existing type annotations (#3281) add py.typed marker --- aim/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 aim/py.typed diff --git a/aim/py.typed b/aim/py.typed new file mode 100644 index 000000000..e69de29bb From eba27a9260d033810d7d4ddbebebe2fb0be81fd0 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Mon, 20 Jan 2025 18:41:04 +0400 Subject: [PATCH 04/30] [fix] Decrease client resources keep-alive time (#3279) --- CHANGELOG.md | 5 +++++ aim/ext/transport/heartbeat.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8e16a51..38700a14a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +### Fixes: +- Decrease client resources keep-alive time (mihran113) + ## 3.27.0 Dec 18, 2024 ### Enhancements: diff --git a/aim/ext/transport/heartbeat.py b/aim/ext/transport/heartbeat.py index e8009576b..3346ccff2 100644 --- a/aim/ext/transport/heartbeat.py +++ b/aim/ext/transport/heartbeat.py @@ -118,7 +118,7 @@ def reset_responses(): class HeartbeatWatcher: - CLIENT_KEEP_ALIVE_TIME_DEFAULT = 30 * 60 # 30 minutes + CLIENT_KEEP_ALIVE_TIME_DEFAULT = 5 * 60 # 5 minutes def __init__(self, heartbeat_pool, keep_alive_time: Union[int, float] = CLIENT_KEEP_ALIVE_TIME_DEFAULT): self._heartbeat_pool = heartbeat_pool From 3ae23634682a6cc765d4d4cb11f13a356b5acc93 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Tue, 11 Feb 2025 16:58:22 +0400 Subject: [PATCH 05/30] [fix] Resolve issues on data points connection on epoch alignment (#3283) --- .github/workflows/nightly-release.yml | 2 +- .github/workflows/pull-request.yml | 16 ++++----- CHANGELOG.md | 1 + aim/web/api/tags/views.py | 3 +- aim/web/ui/src/utils/app/alignMetricData.ts | 38 ++++++++++++++------- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 47b7b2eb1..43707bc87 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -40,7 +40,7 @@ jobs: - name: setup python uses: actions/setup-python@v2 with: - python-version: '3.7' + python-version: '3.8' architecture: x64 - name: install deps diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 152071b5e..348ab9349 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -24,14 +24,14 @@ on: - reopened - edited jobs: - validate-naming-convention: - name: Pull Request's title matches naming convention - runs-on: ubuntu-latest - steps: - - uses: deepakputhraya/action-pr-title@master - with: - regex: '^\[(?:feat|fix|doc|refactor|deprecation)\]\s[A-Z].*(? { + return ( + epoch + + (epochs[epoch].length > 1 + ? 
(0.99 / epochs[epoch].length) * + epochs[epoch].indexOf(metric.data.steps[i]) + : 0) + ); + }), + ]; + let yValues = [...metric.data.values]; + let pointsArray = []; + // Combine the x and y axis arrays into an array of points + for (let idx = 0; idx < xValues.length; idx++) { + pointsArray[idx] = [xValues[idx], yValues[idx]]; + } + // Sort the combined array based on the first element of the point (epoch) + pointsArray.sort(function (a, b) { + return a[0] - b[0]; + }); metric.data = { ...metric.data, - xValues: [ - ...metric.data.epochs.map((epoch, i) => { - return ( - epoch + - (epochs[epoch].length > 1 - ? (0.99 / epochs[epoch].length) * - epochs[epoch].indexOf(metric.data.steps[i]) - : 0) - ); - }), - ], - yValues: [...metric.data.values], + // Separate the x and y axis values back into xValues and yValues + xValues: pointsArray.map((point) => point[0]), + yValues: pointsArray.map((point) => point[1]), }; } } From c6e0c7f60a684e0bfe2e54ed28b93ed5e4fe4100 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:06:26 +0400 Subject: [PATCH 06/30] [fix] Correct indentation on query proxy object return statement (#3287) --- aim/storage/proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aim/storage/proxy.py b/aim/storage/proxy.py index d9c62c8c2..8d967837e 100644 --- a/aim/storage/proxy.py +++ b/aim/storage/proxy.py @@ -174,7 +174,7 @@ def __call__(self): if self.cache is not None and cache_key is not None: self.cache[cache_key] = res - return res + return res class AimObjectProxy(with_metaclass(_ObjectProxyMetaType)): From 2d9f3b8b0f4ef23fc94818847e48307dbc5f6564 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:24:08 +0400 Subject: [PATCH 07/30] [feat] Skip metrics check when run is known to yield false result (#3288) * [feat] Skip metrics check when run is known to yeld false result * [fix] Code style checks * [fix] More styling errors --- aim/sdk/query_analyzer.py | 151 +++++++++++++++++++++++++++++++++ aim/sdk/sequence_collection.py | 46 ++++++---- 2 files changed, 182 insertions(+), 15 deletions(-) create mode 100644 aim/sdk/query_analyzer.py diff --git a/aim/sdk/query_analyzer.py b/aim/sdk/query_analyzer.py new file mode 100644 index 000000000..1930bcf78 --- /dev/null +++ b/aim/sdk/query_analyzer.py @@ -0,0 +1,151 @@ +import ast + +from typing import Any, List, Tuple + + +class Unknown(ast.AST): + pass + + +Unknown = Unknown() # create a single instance of value node + + +class QueryExpressionTransformer(ast.NodeTransformer): + def __init__(self, *, var_names: List[str]): + self._var_names = var_names + + def transform(self, expr: str) -> Tuple[str, bool]: + node = ast.parse(expr, mode='eval') + transformed = self.visit(node) + if transformed is Unknown: + return expr, False + else: + return ast.unparse(transformed), True + + def visit_Expression(self, node: ast.Expression) -> Any: + node: ast.Expression = self.generic_visit(node) + if node.body is Unknown: + return Unknown + return node + + def visit_Expr(self, node: ast.Expr) -> Any: + node: ast.Expr = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Constant(self, node: ast.Constant) -> Any: + return node + + def visit_JoinedStr(self, node: ast.JoinedStr) -> Any: + node: ast.JoinedStr = self.generic_visit(node) + for val in node.values: + if val is Unknown: + return Unknown + return node + + def 
visit_FormattedValue(self, node: ast.FormattedValue) -> Any: + node: ast.FormattedValue = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in self._var_names: + return Unknown + else: + return node + + def visit_Compare(self, node: ast.Compare) -> Any: + node: ast.Compare = self.generic_visit(node) + if node.left is Unknown: + return Unknown + for comp in node.comparators: + if comp is Unknown: + return Unknown + return node + + def visit_List(self, node: ast.List) -> Any: + node: ast.List = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Tuple(self, node: ast.Tuple) -> Any: + node: ast.Tuple = self.generic_visit(node) + for sub in node.elts: + if sub is Unknown: + return Unknown + return node + + def visit_Dict(self, node: ast.Dict) -> Any: + node: ast.Dict = self.generic_visit(node) + for key in node.keys: + if key is Unknown: + return Unknown + for val in node.values: + if val is Unknown: + return Unknown + return node + + def visit_BoolOp(self, node: ast.BoolOp) -> Any: + node: ast.BoolOp = self.generic_visit(node) + node_values = list(filter(lambda x: x is not Unknown, node.values)) + if isinstance(node.op, ast.And): + if len(node_values) == 1: + return node_values[0] + elif len(node_values) == 0: + return Unknown + else: + if len(node_values) < len(node.values): + return Unknown + return ast.BoolOp(op=node.op, values=node_values) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + node: ast.UnaryOp = self.generic_visit(node) + if node.operand is Unknown: + return Unknown + return node + + def visit_BinOp(self, node: ast.BinOp) -> Any: + node: ast.BinOp = self.generic_visit(node) + if node.left is Unknown or node.right is Unknown: + return Unknown + return node + + def visit_IfExp(self, node: ast.IfExp) -> Any: + node: ast.IfExp = self.generic_visit(node) + if node.test is Unknown or node.body is Unknown or node.orelse is Unknown: + return Unknown + return node + + def visit_Attribute(self, node: ast.Attribute) -> Any: + node: ast.Attribute = self.generic_visit(node) + if node.value is Unknown: + return Unknown + return node + + def visit_Call(self, node: ast.Call) -> Any: + node: ast.Call = self.generic_visit(node) + if node.func is Unknown: + return Unknown + for arg in node.args: + if arg is Unknown: + return Unknown + for kwarg in node.keywords: + if kwarg is Unknown: + return Unknown + return node + + def visit_Subscript(self, node: ast.Subscript) -> Any: + node: ast.Subscript = self.generic_visit(node) + if node.value is Unknown or node.slice is Unknown: + return Unknown + return node + + def visit_Slice(self, node: ast.Slice) -> Any: + node: ast.Slice = self.generic_visit(node) + if node.lower is Unknown or node.upper is Unknown or node.step is Unknown: + return Unknown + return node diff --git a/aim/sdk/sequence_collection.py b/aim/sdk/sequence_collection.py index 5738a8e28..3c4699bc2 100644 --- a/aim/sdk/sequence_collection.py +++ b/aim/sdk/sequence_collection.py @@ -3,17 +3,20 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Iterator +from tqdm import tqdm + +from aim.sdk.query_analyzer import QueryExpressionTransformer from aim.sdk.query_utils import RunView, SequenceView from aim.sdk.sequence import Sequence from aim.sdk.types import QueryReportMode from aim.storage.query import RestrictedPythonQuery -from tqdm import tqdm if TYPE_CHECKING: + from pandas import DataFrame + from 
aim.sdk.repo import Repo from aim.sdk.run import Run - from pandas import DataFrame logger = logging.getLogger(__name__) @@ -170,20 +173,33 @@ def iter_runs(self) -> Iterator['SequenceCollection']: if self.report_mode == QueryReportMode.PROGRESS_BAR: progress_bar = tqdm(total=total_runs) + seq_var = self.seq_cls.sequence_name() + t = QueryExpressionTransformer(var_names=[seq_var, ]) + run_expr, is_transformed = t.transform(self.query) + run_query = RestrictedPythonQuery(run_expr) + for run in runs_iterator: - seq_collection = SingleRunSequenceCollection( - run, - self.seq_cls, - self.query, - runs_proxy_cache=self.runs_proxy_cache, - timezone_offset=self._timezone_offset, - ) - if self.report_mode == QueryReportMode.PROGRESS_TUPLE: - yield seq_collection, (runs_counter, total_runs) - else: - if self.report_mode == QueryReportMode.PROGRESS_BAR: - progress_bar.update(1) - yield seq_collection + check_run_sequences = True + if is_transformed: + run_view = RunView(run, runs_proxy_cache=self.runs_proxy_cache, timezone_offset=self._timezone_offset) + match = run_query.check(**{'run': run_view}) + if not match: + check_run_sequences = False + + if check_run_sequences: + seq_collection = SingleRunSequenceCollection( + run, + self.seq_cls, + self.query, + runs_proxy_cache=self.runs_proxy_cache, + timezone_offset=self._timezone_offset, + ) + if self.report_mode == QueryReportMode.PROGRESS_TUPLE: + yield seq_collection, (runs_counter, total_runs) + else: + if self.report_mode == QueryReportMode.PROGRESS_BAR: + progress_bar.update(1) + yield seq_collection runs_counter += 1 def iter(self) -> Iterator[Sequence]: From 17660207610e554776c6ddd945d8133716151421 Mon Sep 17 00:00:00 2001 From: Fabian Keller Date: Thu, 13 Feb 2025 14:33:17 +0100 Subject: [PATCH 08/30] [chore] Bump ruff version from 0.3.3 to 0.9.2 and fix some invalid/dead noqas (#3282) --- aim/acme.py | 3 ++- aim/cli/convert/commands.py | 2 +- aim/cli/init/commands.py | 4 ++-- aim/cli/runs/utils.py | 2 +- aim/cli/server/commands.py | 2 +- aim/cli/up/commands.py | 2 +- aim/ext/notifier/notifier.py | 2 +- aim/ext/sshfs/utils.py | 2 +- aim/ext/tensorboard_tracker/tracker.py | 2 +- aim/ext/transport/heartbeat.py | 6 ++---- aim/ext/transport/message_utils.py | 3 ++- aim/ext/transport/utils.py | 2 +- aim/fastai.py | 2 +- aim/hf_dataset.py | 2 +- aim/hugging_face.py | 2 +- aim/keras.py | 3 ++- aim/keras_tuner.py | 2 +- aim/mxnet.py | 2 +- aim/optuna.py | 2 +- aim/paddle.py | 2 +- aim/prophet.py | 2 +- aim/pytorch.py | 3 ++- aim/pytorch_ignite.py | 2 +- aim/pytorch_lightning.py | 2 +- aim/sb3.py | 2 +- aim/sdk/adapters/fastai.py | 8 ++++++-- aim/sdk/adapters/keras.py | 2 +- aim/sdk/adapters/lightgbm.py | 3 +-- aim/sdk/adapters/mxnet.py | 2 +- aim/sdk/adapters/pytorch_ignite.py | 5 ++--- aim/sdk/adapters/xgboost.py | 3 +-- aim/sdk/callbacks/caller.py | 2 +- aim/sdk/objects/io/wavfile.py | 12 ++++++------ aim/sdk/repo.py | 4 ++-- aim/sdk/sequences/figure_sequence.py | 2 +- aim/sdk/types.py | 3 ++- aim/storage/hashing/hashing.py | 2 +- aim/storage/proxy.py | 4 ++-- aim/storage/query.py | 2 +- aim/storage/types.py | 4 +++- aim/tensorflow.py | 3 ++- aim/utils/__init__.py | 3 ++- aim/web/api/dashboard_apps/views.py | 2 +- aim/web/api/dashboards/views.py | 8 ++++---- aim/web/api/experiments/views.py | 2 +- aim/web/api/runs/object_views.py | 6 +++--- aim/web/middlewares/profiler/profiler.py | 2 +- examples/pytorch_lightning_track.py | 2 +- examples/pytorch_track.py | 2 +- examples/pytorch_track_images.py | 2 +- performance_tests/sdk/queries.py | 6 +++--- 
requirements.dev.txt | 2 +- ruff.toml | 6 +++++- 53 files changed, 88 insertions(+), 76 deletions(-) diff --git a/aim/acme.py b/aim/acme.py index 44884fd7d..30cce3037 100644 --- a/aim/acme.py +++ b/aim/acme.py @@ -1,2 +1,3 @@ # Alias to SDK acme interface -from aim.sdk.adapters.acme import AimCallback, AimWriter # noqa F401 +from aim.sdk.adapters.acme import AimCallback as AimCallback +from aim.sdk.adapters.acme import AimWriter as AimWriter diff --git a/aim/cli/convert/commands.py b/aim/cli/convert/commands.py index 4e23ef806..14160a536 100644 --- a/aim/cli/convert/commands.py +++ b/aim/cli/convert/commands.py @@ -40,7 +40,7 @@ def convert_tensorboard(ctx, logdir, flat, no_cache): @click.option('--flat', '-f', required=False, is_flag=True, default=False) def convert_tensorflow(ctx, logdir, flat): click.secho( - "WARN: Command 'tf' is deprecated and will be removed in future releases," " please use 'tensorboard' instead.", + "WARN: Command 'tf' is deprecated and will be removed in future releases, please use 'tensorboard' instead.", fg='red', ) repo_inst = ctx.obj['repo_inst'] diff --git a/aim/cli/init/commands.py b/aim/cli/init/commands.py index 1ef6fc354..4d538ec84 100644 --- a/aim/cli/init/commands.py +++ b/aim/cli/init/commands.py @@ -20,7 +20,7 @@ def init(repo, yes, skip_if_exists): re_init = False if Repo.exists(repo_path): if yes and skip_if_exists: - raise click.BadParameter('Conflicting init options.' 'Either specify -y/--yes or -s/--skip-if-exists') + raise click.BadParameter('Conflicting init options.Either specify -y/--yes or -s/--skip-if-exists') elif yes: re_init = True elif skip_if_exists: @@ -28,7 +28,7 @@ def init(repo, yes, skip_if_exists): return else: re_init = click.confirm( - 'Aim repository is already initialized. ' 'Do you want to re-initialize to empty Aim repository?' + 'Aim repository is already initialized. Do you want to re-initialize to empty Aim repository?' ) if not re_init: return diff --git a/aim/cli/runs/utils.py b/aim/cli/runs/utils.py index ec4c332f7..f2b64b13d 100644 --- a/aim/cli/runs/utils.py +++ b/aim/cli/runs/utils.py @@ -48,7 +48,7 @@ def upload_repo_runs(buffer: io.BytesIO, bucket_name: str) -> Tuple[bool, str]: import boto3 except ImportError: raise RuntimeError( - "This command requires 'boto3' to be installed. " 'Please install it with command: \n pip install boto3' + "This command requires 'boto3' to be installed. Please install it with command: \n pip install boto3" ) try: diff --git a/aim/cli/server/commands.py b/aim/cli/server/commands.py index b84c1ae68..7c1587021 100644 --- a/aim/cli/server/commands.py +++ b/aim/cli/server/commands.py @@ -95,5 +95,5 @@ def server(host, port, repo, ssl_keyfile, ssl_certfile, base_path, log_level, de ) exec_cmd(cmd, stream_output=True) except ShellCommandException: - click.echo('Failed to run Aim Tracking Server. ' 'Please see the logs above for details.') + click.echo('Failed to run Aim Tracking Server. Please see the logs above for details.') exit(1) diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index 2044c75e4..de8ed008d 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -96,7 +96,7 @@ def up( db_cmd = build_db_upgrade_command() exec_cmd(db_cmd, stream_output=True) except ShellCommandException: - click.echo('Failed to initialize Aim DB. ' 'Please see the logs above for details.') + click.echo('Failed to initialize Aim DB. 
Please see the logs above for details.') return if port == 0: diff --git a/aim/ext/notifier/notifier.py b/aim/ext/notifier/notifier.py index 1b237af62..ca73fe59a 100644 --- a/aim/ext/notifier/notifier.py +++ b/aim/ext/notifier/notifier.py @@ -34,7 +34,7 @@ def notify(self, message: Optional[str] = None, **kwargs): except Exception as e: attempt += 1 if attempt == self.MAX_RETRIES: - logger.error(f'Notifier {sub} failed to send message "{message}". ' f'No retries left.') + logger.error(f'Notifier {sub} failed to send message "{message}". No retries left.') raise NotificationSendError(e) else: logger.error( diff --git a/aim/ext/sshfs/utils.py b/aim/ext/sshfs/utils.py index 3dba53168..170f9d4fe 100644 --- a/aim/ext/sshfs/utils.py +++ b/aim/ext/sshfs/utils.py @@ -197,7 +197,7 @@ def unmount_remote_repo(mount_point: str, mount_root: str): if exit_code != 0: # in case of failure log warning so the user can unmount manually if needed logger.warning( - f'Could not unmount path: {mount_point}.\n' f'Please unmount manually using command:\n' f'{" ".join(cmd)}' + f'Could not unmount path: {mount_point}.\nPlease unmount manually using command:\n{" ".join(cmd)}' ) else: shutil.rmtree(mount_root) diff --git a/aim/ext/tensorboard_tracker/tracker.py b/aim/ext/tensorboard_tracker/tracker.py index 59af672f8..f902e0806 100644 --- a/aim/ext/tensorboard_tracker/tracker.py +++ b/aim/ext/tensorboard_tracker/tracker.py @@ -31,7 +31,7 @@ def _decode_histogram(value): # This is a bit weird but it seems the histogram counts is usually padded by 0 as tensorboard # only stores the right limits? - # See https://github.com/pytorch/pytorch/blob/7d2a18da0b3427fcbe44b461a0aa508194535885/torch/utils/tensorboard/summary.py#L390 # noqa + # See https://github.com/pytorch/pytorch/blob/7d2a18da0b3427fcbe44b461a0aa508194535885/torch/utils/tensorboard/summary.py#L390 bin_counts = bin_counts[1:] bin_range = (bucket_limits[0], bucket_limits[-1]) diff --git a/aim/ext/transport/heartbeat.py b/aim/ext/transport/heartbeat.py index 3346ccff2..d6390d63e 100644 --- a/aim/ext/transport/heartbeat.py +++ b/aim/ext/transport/heartbeat.py @@ -15,10 +15,8 @@ class HeartbeatSender(object): HEARTBEAT_INTERVAL_DEFAULT = 10 NETWORK_CHECK_INTERVAL = 180 - NETWORK_UNSTABLE_WARNING_TEMPLATE = ( - 'Network connection between client `{}` ' 'and server `{}` appears to be unstable.' - ) - NETWORK_ABSENT_WARNING_TEMPLATE = 'Network connection between client `{}` ' 'and server `{}` appears to be absent.' + NETWORK_UNSTABLE_WARNING_TEMPLATE = 'Network connection between client `{}` and server `{}` appears to be unstable.' + NETWORK_ABSENT_WARNING_TEMPLATE = 'Network connection between client `{}` and server `{}` appears to be absent.' 
def __init__( self, diff --git a/aim/ext/transport/message_utils.py b/aim/ext/transport/message_utils.py index 127b1f7e5..daa1823b6 100644 --- a/aim/ext/transport/message_utils.py +++ b/aim/ext/transport/message_utils.py @@ -5,7 +5,8 @@ from typing import Iterator, Tuple from aim.storage.object import CustomObject -from aim.storage.treeutils import decode_tree, encode_tree # noqa +from aim.storage.treeutils import decode_tree as decode_tree +from aim.storage.treeutils import encode_tree as encode_tree from aim.storage.types import BLOB diff --git a/aim/ext/transport/utils.py b/aim/ext/transport/utils.py index 037ad1262..b556692fe 100644 --- a/aim/ext/transport/utils.py +++ b/aim/ext/transport/utils.py @@ -13,7 +13,7 @@ def inner(func): def wrapper(*args, **kwargs): try: return func(*args, **kwargs) - except exc_type as e: # noqa + except exc_type: if error_message is not None: logger.error(error_message) raise RuntimeError(error_message) diff --git a/aim/fastai.py b/aim/fastai.py index 275882894..ab00bee14 100644 --- a/aim/fastai.py +++ b/aim/fastai.py @@ -1,2 +1,2 @@ # Alias to SDK fast.ai interface -from aim.sdk.adapters.fastai import AimCallback # noqa F401 +from aim.sdk.adapters.fastai import AimCallback as AimCallback diff --git a/aim/hf_dataset.py b/aim/hf_dataset.py index e629b15c9..00ecc9cc9 100644 --- a/aim/hf_dataset.py +++ b/aim/hf_dataset.py @@ -1,2 +1,2 @@ # Alias to SDK Hugging Face Datasets interface -from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset # noqa F401 +from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset as HFDataset diff --git a/aim/hugging_face.py b/aim/hugging_face.py index 9fbde32ec..692ec2486 100644 --- a/aim/hugging_face.py +++ b/aim/hugging_face.py @@ -1,2 +1,2 @@ # Alias to SDK Hugging Face interface -from aim.sdk.adapters.hugging_face import AimCallback # noqa F401 +from aim.sdk.adapters.hugging_face import AimCallback as AimCallback diff --git a/aim/keras.py b/aim/keras.py index 3383dff65..e1c6ed28f 100644 --- a/aim/keras.py +++ b/aim/keras.py @@ -1,2 +1,3 @@ # Alias to SDK Keras interface -from aim.sdk.adapters.keras import AimCallback, AimTracker # noqa F401 +from aim.sdk.adapters.keras import AimCallback as AimCallback +from aim.sdk.adapters.keras import AimTracker as AimTracker diff --git a/aim/keras_tuner.py b/aim/keras_tuner.py index 5f6577cae..5d264e64d 100644 --- a/aim/keras_tuner.py +++ b/aim/keras_tuner.py @@ -1,2 +1,2 @@ # Alias to SDK Keras-Tuner interface -from aim.sdk.adapters.keras_tuner import AimCallback # noqa F401 +from aim.sdk.adapters.keras_tuner import AimCallback as AimCallback diff --git a/aim/mxnet.py b/aim/mxnet.py index 403d33d40..ceacfb118 100644 --- a/aim/mxnet.py +++ b/aim/mxnet.py @@ -1,2 +1,2 @@ # Alias to SDK mxnet interface -from aim.sdk.adapters.mxnet import AimLoggingHandler # noqa F401 +from aim.sdk.adapters.mxnet import AimLoggingHandler as AimLoggingHandler diff --git a/aim/optuna.py b/aim/optuna.py index 5069d2469..28d0b1dbf 100644 --- a/aim/optuna.py +++ b/aim/optuna.py @@ -1,2 +1,2 @@ # Alias to SDK Optuna interface -from aim.sdk.adapters.optuna import AimCallback # noqa F401 +from aim.sdk.adapters.optuna import AimCallback as AimCallback diff --git a/aim/paddle.py b/aim/paddle.py index 0c4948641..9069d936a 100644 --- a/aim/paddle.py +++ b/aim/paddle.py @@ -1,2 +1,2 @@ # Alias to SDK PaddlePaddle interface -from aim.sdk.adapters.paddle import AimCallback # noqa F401 +from aim.sdk.adapters.paddle import AimCallback as AimCallback diff --git a/aim/prophet.py b/aim/prophet.py index 
1a43316f4..661e95cd4 100644 --- a/aim/prophet.py +++ b/aim/prophet.py @@ -1,2 +1,2 @@ # Alias to SDK Prophet interface -from aim.sdk.adapters.prophet import AimLogger # noqa F401 +from aim.sdk.adapters.prophet import AimLogger as AimLogger diff --git a/aim/pytorch.py b/aim/pytorch.py index c493b7a84..677a68f88 100644 --- a/aim/pytorch.py +++ b/aim/pytorch.py @@ -1,2 +1,3 @@ # Alias to SDK PyTorch utils -from aim.sdk.adapters.pytorch import track_params_dists, track_gradients_dists # noqa +from aim.sdk.adapters.pytorch import track_gradients_dists as track_gradients_dists +from aim.sdk.adapters.pytorch import track_params_dists as track_params_dists diff --git a/aim/pytorch_ignite.py b/aim/pytorch_ignite.py index 08cd67ce7..2189c6ddf 100644 --- a/aim/pytorch_ignite.py +++ b/aim/pytorch_ignite.py @@ -1,2 +1,2 @@ # Alias to SDK PyTorch Ignite interface -from aim.sdk.adapters.pytorch_ignite import AimLogger # noqa F401 +from aim.sdk.adapters.pytorch_ignite import AimLogger as AimLogger diff --git a/aim/pytorch_lightning.py b/aim/pytorch_lightning.py index 50d10c1aa..b9a3405f9 100644 --- a/aim/pytorch_lightning.py +++ b/aim/pytorch_lightning.py @@ -1,2 +1,2 @@ # Alias to SDK PyTorch Lightning interface -from aim.sdk.adapters.pytorch_lightning import AimLogger # noqa F401 +from aim.sdk.adapters.pytorch_lightning import AimLogger as AimLogger diff --git a/aim/sb3.py b/aim/sb3.py index 43fd7899e..78bdec8ee 100644 --- a/aim/sb3.py +++ b/aim/sb3.py @@ -1,2 +1,2 @@ # Alias to SDK sb3 interface -from aim.sdk.adapters.sb3 import AimCallback # noqa F401 +from aim.sdk.adapters.sb3 import AimCallback as AimCallback diff --git a/aim/sdk/adapters/fastai.py b/aim/sdk/adapters/fastai.py index 37390444c..88b7c4fdd 100644 --- a/aim/sdk/adapters/fastai.py +++ b/aim/sdk/adapters/fastai.py @@ -11,7 +11,7 @@ from fastcore.basics import detuplify, ignore_exceptions, store_attr except ImportError: raise RuntimeError( - 'This contrib module requires fastai to be installed. ' 'Please install it with command: \n pip install fastai' + 'This contrib module requires fastai to be installed. Please install it with command: \n pip install fastai' ) logger = getLogger(__name__) @@ -107,7 +107,11 @@ def gather_args(self): args['n_inp'] = n_inp xb = self.dls.valid.one_batch()[:n_inp] args.update( - {f'input {n+1} dim {i+1}': d for n in range(n_inp) for i, d in enumerate(list(detuplify(xb[n]).shape))} + { + f'input {n + 1} dim {i + 1}': d + for n in range(n_inp) + for i, d in enumerate(list(detuplify(xb[n]).shape)) + } ) except Exception: logger.warning('Failed to gather input dimensions') diff --git a/aim/sdk/adapters/keras.py b/aim/sdk/adapters/keras.py index 4a2141249..10af8b711 100644 --- a/aim/sdk/adapters/keras.py +++ b/aim/sdk/adapters/keras.py @@ -9,7 +9,7 @@ from keras.callbacks import Callback except ImportError: raise RuntimeError( - 'This contrib module requires keras to be installed. ' 'Please install it with command: \n pip install keras' + 'This contrib module requires keras to be installed. Please install it with command: \n pip install keras' ) diff --git a/aim/sdk/adapters/lightgbm.py b/aim/sdk/adapters/lightgbm.py index f2bae4e16..f006cd971 100644 --- a/aim/sdk/adapters/lightgbm.py +++ b/aim/sdk/adapters/lightgbm.py @@ -8,8 +8,7 @@ from lightgbm.callback import CallbackEnv except ImportError: raise RuntimeError( - 'This contrib module requires Lightgbm to be installed. ' - 'Please install it with command: \n pip install lightgbm' + 'This contrib module requires Lightgbm to be installed. 
Please install it with command: \n pip install lightgbm' ) diff --git a/aim/sdk/adapters/mxnet.py b/aim/sdk/adapters/mxnet.py index e10d4a19c..88f005dd8 100644 --- a/aim/sdk/adapters/mxnet.py +++ b/aim/sdk/adapters/mxnet.py @@ -75,7 +75,7 @@ def train_begin(self, estimator: Optional[Estimator], *args, **kwargs): optimizer = trainer.optimizer.__class__.__name__ lr = trainer.learning_rate - estimator.logger.info('Training begin: using optimizer %s ' 'with current learning rate %.4f ', optimizer, lr) + estimator.logger.info('Training begin: using optimizer %s with current learning rate %.4f ', optimizer, lr) if estimator.max_epoch: estimator.logger.info('Train for %d epochs.', estimator.max_epoch) else: diff --git a/aim/sdk/adapters/pytorch_ignite.py b/aim/sdk/adapters/pytorch_ignite.py index 42cf7d0f2..6a9506c54 100644 --- a/aim/sdk/adapters/pytorch_ignite.py +++ b/aim/sdk/adapters/pytorch_ignite.py @@ -8,7 +8,7 @@ from torch.optim import Optimizer except ImportError: raise RuntimeError( - 'This contrib module requires PyTorch to be installed. ' 'Please install it with command: \n pip install torch' + 'This contrib module requires PyTorch to be installed. Please install it with command: \n pip install torch' ) try: from ignite.contrib.handlers.base_logger import ( @@ -185,8 +185,7 @@ def __call__(self, engine: Engine, logger: AimLogger, event_name: Union[str, Eve if not isinstance(global_step, int): raise TypeError( - f'global_step must be int, got {type(global_step)}.' - ' Please check the output of global_step_transform.' + f'global_step must be int, got {type(global_step)}. Please check the output of global_step_transform.' ) metrics = {} diff --git a/aim/sdk/adapters/xgboost.py b/aim/sdk/adapters/xgboost.py index 8d9926287..832110f25 100644 --- a/aim/sdk/adapters/xgboost.py +++ b/aim/sdk/adapters/xgboost.py @@ -8,8 +8,7 @@ from xgboost.callback import TrainingCallback except ImportError: raise RuntimeError( - 'This contrib module requires XGBoost to be installed. ' - 'Please install it with command: \n pip install xgboost' + 'This contrib module requires XGBoost to be installed. Please install it with command: \n pip install xgboost' ) diff --git a/aim/sdk/callbacks/caller.py b/aim/sdk/callbacks/caller.py index 6ac0c29ae..387406e22 100644 --- a/aim/sdk/callbacks/caller.py +++ b/aim/sdk/callbacks/caller.py @@ -42,7 +42,7 @@ def trigger(self, event_name: str, **kwargs): for handler in handlers: try: handler(**all_kwargs) - except Exception: # noqa + except Exception: # TODO catch errors on handler invocation (nice-to-have) logger.warning(f"Failed to run callback '{handler.__name__}'.") logger.warning(traceback.format_exc()) diff --git a/aim/sdk/objects/io/wavfile.py b/aim/sdk/objects/io/wavfile.py index 5c58daf4a..34d187c7a 100644 --- a/aim/sdk/objects/io/wavfile.py +++ b/aim/sdk/objects/io/wavfile.py @@ -316,7 +316,7 @@ def _raise_bad_format(format_tag): except ValueError: format_name = f'{format_tag:#06x}' raise ValueError( - f"Unknown wave file format: {format_name}. Supported formats: {', '.join(x.name for x in KNOWN_WAVE_FORMATS)}" + f'Unknown wave file format: {format_name}. 
Supported formats: {", ".join(x.name for x in KNOWN_WAVE_FORMATS)}' ) @@ -447,12 +447,12 @@ def _read_data_chunk(fid, format_tag, channels, bit_depth, is_big_endian, block_ # Remaining bit depths can map directly to signed numpy dtypes dtype = f'{fmt}i{bytes_per_sample}' else: - raise ValueError('Unsupported bit depth: the WAV file ' f'has {bit_depth}-bit integer data.') + raise ValueError(f'Unsupported bit depth: the WAV file has {bit_depth}-bit integer data.') elif format_tag == WAVE_FORMAT.IEEE_FLOAT: if bit_depth in {32, 64}: dtype = f'{fmt}f{bytes_per_sample}' else: - raise ValueError('Unsupported bit depth: the WAV file ' f'has {bit_depth}-bit floating-point data.') + raise ValueError(f'Unsupported bit depth: the WAV file has {bit_depth}-bit floating-point data.') else: _raise_bad_format(format_tag) @@ -480,7 +480,7 @@ def _read_data_chunk(fid, format_tag, channels, bit_depth, is_big_endian, block_ data = numpy.memmap(fid, dtype=dtype, mode='c', offset=start, shape=(n_samples,)) fid.seek(start + size) else: - raise ValueError('mmap=True not compatible with ' f'{bytes_per_sample}-byte container size.') + raise ValueError(f'mmap=True not compatible with {bytes_per_sample}-byte container size.') _handle_pad_byte(fid, size) @@ -516,7 +516,7 @@ def _read_riff_chunk(fid): fmt = '>I' else: # There are also .wav files with "FFIR" or "XFIR" signatures? - raise ValueError(f'File format {repr(str1)} not understood. Only ' "'RIFF' and 'RIFX' supported.") + raise ValueError(f"File format {repr(str1)} not understood. Only 'RIFF' and 'RIFX' supported.") # Size of entire file file_size = struct.unpack(fmt, fid.read(4))[0] + 8 @@ -554,7 +554,7 @@ def read(buffer, mmap=False): if data_chunk_received: # End of file but data successfully read warnings.warn( - 'Reached EOF prematurely; finished at {:d} bytes, ' 'expected {:d} bytes from header.'.format( + 'Reached EOF prematurely; finished at {:d} bytes, expected {:d} bytes from header.'.format( fid.tell(), file_size ), WavFileWarning, diff --git a/aim/sdk/repo.py b/aim/sdk/repo.py index 992794964..6d2471cb2 100644 --- a/aim/sdk/repo.py +++ b/aim/sdk/repo.py @@ -985,7 +985,7 @@ def _backup_run(self, run_hash): from aim.sdk.utils import backup_run if self.is_remote_repo: - self._remote_repo_proxy._restore_run(run_hash) # noqa + self._remote_repo_proxy._restore_run(run_hash) else: backup_run(self, run_hash) @@ -993,7 +993,7 @@ def _restore_run(self, run_hash): from aim.sdk.utils import restore_run_backup if self.is_remote_repo: - self._remote_repo_proxy._restore_run(run_hash) # noqa + self._remote_repo_proxy._restore_run(run_hash) else: restore_run_backup(self, run_hash) diff --git a/aim/sdk/sequences/figure_sequence.py b/aim/sdk/sequences/figure_sequence.py index ff6081e60..885828f79 100644 --- a/aim/sdk/sequences/figure_sequence.py +++ b/aim/sdk/sequences/figure_sequence.py @@ -9,7 +9,7 @@ class Figures(Sequence): @classmethod def allowed_dtypes(cls) -> Union[str, Tuple[str, ...]]: - return (Figure.get_typename(),) # noqa : need a tuple for consitancy + return (Figure.get_typename(),) # need a tuple for consitancy @classmethod def sequence_name(cls) -> str: diff --git a/aim/sdk/types.py b/aim/sdk/types.py index 51fdc72cd..aa70e24ec 100644 --- a/aim/sdk/types.py +++ b/aim/sdk/types.py @@ -1,6 +1,7 @@ -from aim.storage.types import * # noqa F401 from enum import Enum +from aim.storage.types import * # noqa: F403 + class QueryReportMode(Enum): DISABLED = 0 diff --git a/aim/storage/hashing/hashing.py b/aim/storage/hashing/hashing.py index 
1aaa7e52e..eef53c5d2 100644 --- a/aim/storage/hashing/hashing.py +++ b/aim/storage/hashing/hashing.py @@ -11,7 +11,7 @@ from typing import Tuple, Union -from aim.storage.encoding import decode_int64, encode_int64 # noqa +from aim.storage.encoding import decode_int64, encode_int64 from aim.storage.hashing import c_hash from aim.storage.types import ( AimObject, diff --git a/aim/storage/proxy.py b/aim/storage/proxy.py index 8d967837e..cf3d84fca 100644 --- a/aim/storage/proxy.py +++ b/aim/storage/proxy.py @@ -192,8 +192,8 @@ def __name__(self, value): def __class__(self): return self.__wrapped__().__class__ - @__class__.setter # noqa - def __class__(self, value): # noqa + @__class__.setter + def __class__(self, value): self.__wrapped__().__class__ = value @property diff --git a/aim/storage/query.py b/aim/storage/query.py index 0ada6f153..82de23657 100644 --- a/aim/storage/query.py +++ b/aim/storage/query.py @@ -52,7 +52,7 @@ def safer_getattr(object, name, default=None, getattr=getattr): if name == 'format' and isinstance(object, str): raise NotImplementedError('Using format() on a %s is not safe.' % object.__class__.__name__) if name[0] == '_': - raise AttributeError('"{name}" is an invalid attribute name because it ' 'starts with "_"'.format(name=name)) + raise AttributeError('"{name}" is an invalid attribute name because it starts with "_"'.format(name=name)) val = getattr(object, name, default) return val diff --git a/aim/storage/types.py b/aim/storage/types.py index a21caa061..6fbf6e012 100644 --- a/aim/storage/types.py +++ b/aim/storage/types.py @@ -1,6 +1,8 @@ -from aim.storage.utils import BLOB # noqa F401 from typing import Dict, List, Tuple, Union +from aim.storage.utils import BLOB as BLOB + + NoneType = type(None) diff --git a/aim/tensorflow.py b/aim/tensorflow.py index 93ccaed16..17cee6549 100644 --- a/aim/tensorflow.py +++ b/aim/tensorflow.py @@ -1,2 +1,3 @@ # Alias to SDK TensorFlow Keras interface -from aim.sdk.adapters.tensorflow import AimCallback, AimTracker # noqa F401 +from aim.sdk.adapters.tensorflow import AimCallback as AimCallback +from aim.sdk.adapters.tensorflow import AimTracker as AimTracker diff --git a/aim/utils/__init__.py b/aim/utils/__init__.py index 761d0cd34..c48598750 100644 --- a/aim/utils/__init__.py +++ b/aim/utils/__init__.py @@ -1 +1,2 @@ -from aim.ext.exception_resistant import enable_safe_mode, disable_safe_mode # noqa +from aim.ext.exception_resistant import disable_safe_mode as disable_safe_mode +from aim.ext.exception_resistant import enable_safe_mode as enable_safe_mode diff --git a/aim/web/api/dashboard_apps/views.py b/aim/web/api/dashboard_apps/views.py index 50fb87123..7f1acae26 100644 --- a/aim/web/api/dashboard_apps/views.py +++ b/aim/web/api/dashboard_apps/views.py @@ -19,7 +19,7 @@ @dashboard_apps_router.get('/', response_model=ExploreStateListOut) async def dashboard_apps_list_api(session: Session = Depends(get_session)): - explore_states = session.query(ExploreState).filter(ExploreState.is_archived == False) # noqa + explore_states = session.query(ExploreState).filter(ExploreState.is_archived == False) # noqa: E712 result = [] for es in explore_states: result.append(explore_state_response_serializer(es)) diff --git a/aim/web/api/dashboards/views.py b/aim/web/api/dashboards/views.py index fe12bf69d..5cd16c302 100644 --- a/aim/web/api/dashboards/views.py +++ b/aim/web/api/dashboards/views.py @@ -19,7 +19,7 @@ @dashboards_router.get('/', response_model=List[DashboardOut]) async def dashboards_list_api(session: Session = 
Depends(get_session)): - dashboards_query = session.query(Dashboard).filter(Dashboard.is_archived == False).order_by(Dashboard.updated_at) # noqa + dashboards_query = session.query(Dashboard).filter(Dashboard.is_archived == False).order_by(Dashboard.updated_at) # noqa: E712 result = [] for dashboard in dashboards_query: @@ -50,7 +50,7 @@ async def dashboards_post_api(request_data: DashboardCreateIn, session: Session @dashboards_router.get('/{dashboard_id}/', response_model=DashboardOut) async def dashboards_get_api(dashboard_id: str, session: Session = Depends(get_session)): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) @@ -61,7 +61,7 @@ async def dashboards_get_api(dashboard_id: str, session: Session = Depends(get_s async def dashboards_put_api( dashboard_id: str, request_data: DashboardUpdateIn, session: Session = Depends(get_session) ): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) dashboard_name = request_data.name @@ -77,7 +77,7 @@ async def dashboards_put_api( @dashboards_router.delete('/{dashboard_id}/') async def dashboards_delete_api(dashboard_id: str, session: Session = Depends(get_session)): - dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa + dashboard = session.query(Dashboard).filter(Dashboard.uuid == dashboard_id, Dashboard.is_archived == False).first() # noqa: E712 if not dashboard: raise HTTPException(status_code=404) diff --git a/aim/web/api/experiments/views.py b/aim/web/api/experiments/views.py index 9164eb78f..a47511cd9 100644 --- a/aim/web/api/experiments/views.py +++ b/aim/web/api/experiments/views.py @@ -114,7 +114,7 @@ async def update_experiment_properties_api(exp_id: str, exp_in: ExperimentUpdate if exp_in.archived is not None: if exp_in.archived and len(exp.runs) > 0: raise HTTPException( - status_code=400, detail=(f"Cannot archive experiment '{exp_id}'. " 'Experiment has associated runs.') + status_code=400, detail=(f"Cannot archive experiment '{exp_id}'. 
Experiment has associated runs.") ) exp.archived = exp_in.archived diff --git a/aim/web/api/runs/object_views.py b/aim/web/api/runs/object_views.py index 8e5d5b3f5..6a4b4d0e9 100644 --- a/aim/web/api/runs/object_views.py +++ b/aim/web/api/runs/object_views.py @@ -32,7 +32,7 @@ class CustomObjectApiConfig: sequence_type: type = Sequence resolve_blobs: bool = False - dump_record_fn: callable = lambda x: x.data # noqa E731 + dump_record_fn: callable = lambda x: x.data model: type = BaseModel @staticmethod @@ -165,7 +165,7 @@ class TextApiConfig(CustomObjectApiConfig): class DistributionApiConfig(CustomObjectApiConfig): sequence_type = Distributions resolve_blobs = True - dump_record_fn = lambda x: numpy_to_encodable(x.weights) # noqa E731 + dump_record_fn = lambda x: numpy_to_encodable(x.weights) # noqa: E731 model = DistributionInfo @@ -178,5 +178,5 @@ class AudioApiConfig(CustomObjectApiConfig): class FigureApiConfig(CustomObjectApiConfig): sequence_type = Figures resolve_blobs = True - dump_record_fn = lambda x: x.data # noqa E731 + dump_record_fn = lambda x: x.data # noqa: E731 model = FigureInfo diff --git a/aim/web/middlewares/profiler/profiler.py b/aim/web/middlewares/profiler/profiler.py index 0956dc8b5..fc39f4006 100644 --- a/aim/web/middlewares/profiler/profiler.py +++ b/aim/web/middlewares/profiler/profiler.py @@ -61,7 +61,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: profiler = self.profiler(interval=self._profiler_interval) try: profiler.start() - except: # noqa + except: # noqa: E722 skip_profiling = True else: skip_profiling = False diff --git a/examples/pytorch_lightning_track.py b/examples/pytorch_lightning_track.py index 9d9d8c98b..3a4b23c7c 100644 --- a/examples/pytorch_lightning_track.py +++ b/examples/pytorch_lightning_track.py @@ -4,7 +4,7 @@ if importlib.util.find_spec('lightning'): import lightning.pytorch as pl -elif importlib.util.find_spec('pytorch_lightning'): # noqa F401 +elif importlib.util.find_spec('pytorch_lightning'): # F401 import pytorch_lightning as pl else: raise RuntimeError( diff --git a/examples/pytorch_track.py b/examples/pytorch_track.py index 3927356cb..3c68bd51d 100644 --- a/examples/pytorch_track.py +++ b/examples/pytorch_track.py @@ -90,7 +90,7 @@ def forward(self, x): if i % 30 == 0: logging.info( - 'Epoch [{}/{}], Step [{}/{}], ' 'Loss: {:.4f}'.format( + 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format( epoch + 1, num_epochs, i + 1, total_step, loss.item() ) ) diff --git a/examples/pytorch_track_images.py b/examples/pytorch_track_images.py index adb693a2f..bcf4f627d 100644 --- a/examples/pytorch_track_images.py +++ b/examples/pytorch_track_images.py @@ -94,7 +94,7 @@ def forward(self, x): if i % 30 == 0: logging.info( - 'Epoch [{}/{}], Step [{}/{}], ' 'Loss: {:.4f}'.format( + 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format( epoch + 1, num_epochs, i + 1, total_step, loss.item() ) ) diff --git a/performance_tests/sdk/queries.py b/performance_tests/sdk/queries.py index 514ce5583..38b79bf3d 100644 --- a/performance_tests/sdk/queries.py +++ b/performance_tests/sdk/queries.py @@ -1,7 +1,7 @@ -query_0 = 'run.hparams.benchmark == "glue" ' 'and run.hparams.dataset == "cola" ' 'and metric.context.subset != "train"' -query_1 = 'run.hparams.benchmark == "glue" ' 'and run.hparams.dataset == "cola"' +query_0 = 'run.hparams.benchmark == "glue" and run.hparams.dataset == "cola" and metric.context.subset != "train"' +query_1 = 'run.hparams.benchmark == "glue" and run.hparams.dataset == "cola"' query_2 = 
'run.hparams.benchmark == "glue"' -query_3 = 'run.hparams.dataset == "cola" ' 'and run.experiment.name != "baseline-warp_4-cola"' +query_3 = 'run.hparams.dataset == "cola" and run.experiment.name != "baseline-warp_4-cola"' queries = { diff --git a/requirements.dev.txt b/requirements.dev.txt index 936433a7b..b36a02f59 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,3 +1,3 @@ wheel >= 0.31.0 twine >= 1.11.0 -ruff == 0.3.3 +ruff == 0.9.2 diff --git a/ruff.toml b/ruff.toml index d164ced82..e23481fd0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,7 +11,11 @@ exclude = [ inline-quotes = "single" [lint] -extend-select = ["I"] +extend-select = [ + "I", + "PGH004", # blanket-noqa + "RUF100", # unused-noqa +] [lint.isort] no-lines-before = ["future", "standard-library", "first-party"] From db6fcc1e09351811682e5a58f8250e5bc9af2bc7 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Tue, 18 Feb 2025 16:29:53 +0400 Subject: [PATCH 09/30] [fix] Move performance tests to local mac mini (#3290) --- .github/workflows/pull-request.yml | 4 +-- performance_tests/BASELINE | 42 +++++++++++++++--------------- performance_tests/conftest.py | 2 +- tests/requirements.txt | 1 + 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 348ab9349..46fc96103 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -68,8 +68,8 @@ jobs: storage-performance-checks: needs: run-checks - concurrency: perf-tests - runs-on: [self-hosted, performance-tests] + concurrency: storage-performance-checks + runs-on: [self-hosted, perf-tests] name: Performance tests steps: - name: checkout diff --git a/performance_tests/BASELINE b/performance_tests/BASELINE index 1f403fadb..ebe22e938 100644 --- a/performance_tests/BASELINE +++ b/performance_tests/BASELINE @@ -1,21 +1,21 @@ -test_collect_metrics_data_0 1.8545958518981933 -test_collect_metrics_data_1 1.9959398269653321 -test_collect_metrics_data_2 10.835494375228881 -test_collect_metrics_data_3 1.8672633171081543 -test_collect_runs_data_0 0.8988437175750732 -test_collect_runs_data_1 1.039186429977417 -test_collect_runs_data_2 3.469265604019165 -test_collect_runs_data_3 0.9486905574798584 -test_query_metrics_0 1.6766140460968018 -test_query_metrics_1 1.6763684749603271 -test_query_metrics_2 1.6051365375518798 -test_query_metrics_3 1.5391615390777589 -test_query_runs_0 0.8991998672485352 -test_query_runs_1 0.9259328842163086 -test_query_runs_2 0.839762544631958 -test_query_runs_3 0.832861852645874 -test_container_open 0.1440361499786377 -test_iterative_access 4.000607919692993 -test_random_access_0 0.663770055770874 -test_random_access_1 1.4745195388793946 -test_random_access_2 2.424658107757568 \ No newline at end of file +test_collect_metrics_data_0 0.3717397689819336 +test_collect_metrics_data_1 0.3963047981262207 +test_collect_metrics_data_2 2.7405614376068117 +test_collect_metrics_data_3 0.3710219860076904 +test_collect_runs_data_0 0.17322354316711425 +test_collect_runs_data_1 0.20246338844299316 +test_collect_runs_data_2 0.7970072269439697 +test_collect_runs_data_3 0.1911233901977539 +test_query_metrics_0 0.311903190612793 +test_query_metrics_1 0.3122593879699707 +test_query_metrics_2 0.3092495441436768 +test_query_metrics_3 0.288785982131958 +test_query_runs_0 0.17433061599731445 +test_query_runs_1 0.17484822273254394 +test_query_runs_2 0.17181901931762694 +test_query_runs_3 0.1616499423980713 +test_container_open 0.04026708602905273 
+test_iterative_access 1.1857992172241212 +test_random_access_0 0.14068403244018554 +test_random_access_1 0.26419754028320314 +test_random_access_2 0.3941319942474365 \ No newline at end of file diff --git a/performance_tests/conftest.py b/performance_tests/conftest.py index a53afa2e7..95fa82005 100644 --- a/performance_tests/conftest.py +++ b/performance_tests/conftest.py @@ -43,7 +43,7 @@ def pytest_sessionstart(session): _init_test_repos() else: # github actions performance tests on self hosted runner - os.chdir('/home/ubuntu/performance_logs/') + os.chdir('/Users/github/workers/perf-tests/actions-runner/_work/performance_logs') time.sleep(10) diff --git a/tests/requirements.txt b/tests/requirements.txt index e3135d0bf..f2e592cb9 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,6 +2,7 @@ torch tensorflow deeplake<4.0.0 # update when proper documentation is available +azure-storage-blob # for deeplake # hub fastapi>=0.87.0 httpx From 07338cae07b5e119777cabd3d976d56187178be7 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Mon, 24 Feb 2025 23:03:22 +0400 Subject: [PATCH 10/30] [fix] Resolve session refresh issues when db file is replaced (#3294) --- aim/cli/storage/commands.py | 2 +- aim/sdk/query_analyzer.py | 13 ++++++++++++- aim/sdk/run.py | 2 +- aim/sdk/sequence_collection.py | 12 +++++++----- aim/storage/structured/db.py | 4 ++-- aim/utils/deprecation.py | 8 ++++---- aim/web/api/db.py | 7 ++++--- setup.py | 4 ++++ 8 files changed, 35 insertions(+), 17 deletions(-) diff --git a/aim/cli/storage/commands.py b/aim/cli/storage/commands.py index 60ec26614..3210f7a69 100644 --- a/aim/cli/storage/commands.py +++ b/aim/cli/storage/commands.py @@ -56,7 +56,7 @@ def to_3_11(ctx, hashes, yes): try: run = Run(run_hash, repo=repo) if run.check_metrics_version(): - backup_run(run) + backup_run(repo, run.hash) run.update_metrics() index_manager.index(run_hash) else: diff --git a/aim/sdk/query_analyzer.py b/aim/sdk/query_analyzer.py index 1930bcf78..7c7065a24 100644 --- a/aim/sdk/query_analyzer.py +++ b/aim/sdk/query_analyzer.py @@ -10,6 +10,17 @@ class Unknown(ast.AST): Unknown = Unknown() # create a single instance of value node +def unparse(*args, **kwargs): + import sys + + if sys.version_info.minor < 9: + import astunparse + + return astunparse.unparse(*args, **kwargs) + else: + return ast.unparse(*args, **kwargs) + + class QueryExpressionTransformer(ast.NodeTransformer): def __init__(self, *, var_names: List[str]): self._var_names = var_names @@ -20,7 +31,7 @@ def transform(self, expr: str) -> Tuple[str, bool]: if transformed is Unknown: return expr, False else: - return ast.unparse(transformed), True + return unparse(transformed), True def visit_Expression(self, node: ast.Expression) -> Any: node: ast.Expression = self.generic_visit(node) diff --git a/aim/sdk/run.py b/aim/sdk/run.py index 08b89b8c3..59bc4d806 100644 --- a/aim/sdk/run.py +++ b/aim/sdk/run.py @@ -287,7 +287,7 @@ def __init__( raise RuntimeError else: logger.warning(f'Detected sub-optimal format metrics for Run {self.hash}. 
Upgrading...') - backup_path = backup_run(self) + backup_path = backup_run(self.repo, self.hash) try: self.update_metrics() logger.warning(f'Successfully converted Run {self.hash}') diff --git a/aim/sdk/sequence_collection.py b/aim/sdk/sequence_collection.py index 3c4699bc2..62c083d45 100644 --- a/aim/sdk/sequence_collection.py +++ b/aim/sdk/sequence_collection.py @@ -3,20 +3,18 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Iterator -from tqdm import tqdm - from aim.sdk.query_analyzer import QueryExpressionTransformer from aim.sdk.query_utils import RunView, SequenceView from aim.sdk.sequence import Sequence from aim.sdk.types import QueryReportMode from aim.storage.query import RestrictedPythonQuery +from tqdm import tqdm if TYPE_CHECKING: - from pandas import DataFrame - from aim.sdk.repo import Repo from aim.sdk.run import Run + from pandas import DataFrame logger = logging.getLogger(__name__) @@ -174,7 +172,11 @@ def iter_runs(self) -> Iterator['SequenceCollection']: progress_bar = tqdm(total=total_runs) seq_var = self.seq_cls.sequence_name() - t = QueryExpressionTransformer(var_names=[seq_var, ]) + t = QueryExpressionTransformer( + var_names=[ + seq_var, + ] + ) run_expr, is_transformed = t.transform(self.query) run_query = RestrictedPythonQuery(run_expr) diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index 3632bafa3..2849b3be4 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -63,7 +63,6 @@ def __init__(self, path: str, readonly: bool = False): self.engine = create_engine( self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))) ) - self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) self._upgraded = None @classmethod @@ -91,7 +90,8 @@ def caches(self): return self._caches def get_session(self, autocommit=True): - session = self.session_cls() + session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) + session = session_cls() setattr(session, 'autocommit', autocommit) return session diff --git a/aim/utils/deprecation.py b/aim/utils/deprecation.py index 46bd86266..11bf5af30 100644 --- a/aim/utils/deprecation.py +++ b/aim/utils/deprecation.py @@ -10,11 +10,11 @@ def python_version_deprecation_check(): import sys version_info = sys.version_info - if version_info.major == 3 and version_info.minor == 6: + if version_info.major == 3 and version_info.minor == 7: deprecation_warning( - remove_version='3.16', - msg='Python 3.6 has reached EOL. Aim support for Python 3.6 is deprecated!', - remove_msg_template='Python 3.6 support will be dropped in', + remove_version='3.17', + msg='Python 3.7 has reached EOL. 
Aim support for Python 3.7 is deprecated!', + remove_msg_template='Python 3.7 support will be dropped in', ) diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 459fd3112..80aeaa539 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -15,12 +15,12 @@ echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), connect_args={'check_same_thread': False}, ) -SessionLocal = sessionmaker(autoflush=False, bind=engine) Base = declarative_base() def get_session(): - session = SessionLocal() + session_cls = sessionmaker(autoflush=False, bind=engine) + session = session_cls() try: yield session finally: @@ -29,7 +29,8 @@ def get_session(): @contextmanager def get_contexted_session(): - session = SessionLocal() + session_cls = sessionmaker(autoflush=False, bind=engine) + session = session_cls() try: yield session finally: diff --git a/setup.py b/setup.py index 1cfacd0ca..00725b280 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,9 @@ def package_files(directory): 'boto3', ] +if sys.version_info.minor < 9: + REQUIRED += ['astunparse'] + class UploadCommand(Command): """Support setup.py upload.""" @@ -194,6 +197,7 @@ def cytonize_extensions(): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: Implementation :: PyPy', ], ext_modules=cytonize_extensions(), From fe316ddf7bd4395fcdf6453bc3b70a71b1182c02 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Tue, 4 Mar 2025 19:07:15 +0400 Subject: [PATCH 11/30] [fix] Resolve issue with adding duplicate tags (#3296) --- CHANGELOG.md | 6 ++++++ aim/storage/structured/sql_engine/entities.py | 9 +++++++++ aim/utils/deprecation.py | 2 +- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aa9e56be..a67e85036 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ ### Fixes: - Decrease client resources keep-alive time (mihran113) - Fix connection of data points on epoch alignment (mihran113) +- Resolve issue with adding duplicate tags to the same run (mihran113) +- Resolve session refresh issues when db file is replaced (mihran113) + +### Enhancements: +- Skip metrics check when run is known to yield false result (alberttorosyan) +- Correct indentation on query proxy object return statement (alberttorosyan) ## 3.27.0 Dec 18, 2024 diff --git a/aim/storage/structured/sql_engine/entities.py b/aim/storage/structured/sql_engine/entities.py index 84c72158c..554d4c70d 100644 --- a/aim/storage/structured/sql_engine/entities.py +++ b/aim/storage/structured/sql_engine/entities.py @@ -1,3 +1,5 @@ +import logging + from typing import Collection, List, Optional, Union import pytz @@ -26,6 +28,9 @@ from sqlalchemy.orm import joinedload +logger = logging.getLogger(__name__) + + def session_commit_or_flush(session): if getattr(session, 'autocommit', True) and sa_version >= '2.0.0': session.commit() @@ -227,6 +232,10 @@ def unsafe_add_tag(): self._model.tags.append(tag) session.add(self._model) + if value in self.tags: + logger.warning(f'Tag with value: {value} is already present in this run.') + return + session = self._session unsafe_add_tag() try: diff --git a/aim/utils/deprecation.py b/aim/utils/deprecation.py index 11bf5af30..b0bf2a70f 100644 --- a/aim/utils/deprecation.py +++ b/aim/utils/deprecation.py @@ -12,7 +12,7 @@ def python_version_deprecation_check(): version_info = sys.version_info if version_info.major == 3 and version_info.minor == 7: 
deprecation_warning( - remove_version='3.17', + remove_version='3.30', msg='Python 3.7 has reached EOL. Aim support for Python 3.7 is deprecated!', remove_msg_template='Python 3.7 support will be dropped in', ) From b6c0b1f59825b60478aede026cc8f2a91af983f2 Mon Sep 17 00:00:00 2001 From: Maximilian Luz Date: Tue, 11 Mar 2025 16:24:03 +0100 Subject: [PATCH 12/30] [fix] Message stream parsing (#3298) --- CHANGELOG.md | 1 + aim/ext/transport/message_utils.py | 15 +-------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a67e85036..7dcc243e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ### Enhancements: - Skip metrics check when run is known to yield false result (alberttorosyan) - Correct indentation on query proxy object return statement (alberttorosyan) +- Fix spurious assertion error in message stream parsing (qzed) ## 3.27.0 Dec 18, 2024 diff --git a/aim/ext/transport/message_utils.py b/aim/ext/transport/message_utils.py index daa1823b6..cc0422e5a 100644 --- a/aim/ext/transport/message_utils.py +++ b/aim/ext/transport/message_utils.py @@ -46,22 +46,9 @@ def pack_stream(tree: Iterator[Tuple[bytes, bytes]]) -> bytes: yield struct.pack('I', len(key)) + key + struct.pack('?', True) + struct.pack('I', len(val)) + val -def unpack_helper(msg: bytes) -> Tuple[bytes, bytes]: - (key_size,), tail = struct.unpack('I', msg[:4]), msg[4:] - key, tail = tail[:key_size], tail[key_size:] - (is_blob,), tail = struct.unpack('?', tail[:1]), tail[1:] - (value_size,), tail = struct.unpack('I', tail[:4]), tail[4:] - value, tail = tail[:value_size], tail[value_size:] - assert len(tail) == 0 - if is_blob: - yield key, BLOB(data=value) - else: - yield key, value - - def unpack_stream(stream) -> Tuple[bytes, bytes]: for msg in stream: - yield from unpack_helper(msg) + yield from unpack_args(msg) def raise_exception(server_exception): From 3c40c8350e8eb0b06d375870eb1e18e833a071d0 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Thu, 13 Mar 2025 11:45:41 +0400 Subject: [PATCH 13/30] [fix] Handle empty queries (#3299) * [fix] Handle empty queries * [fix] Formatting issues --- aim/sdk/query_analyzer.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/aim/sdk/query_analyzer.py b/aim/sdk/query_analyzer.py index 7c7065a24..f06d589f0 100644 --- a/aim/sdk/query_analyzer.py +++ b/aim/sdk/query_analyzer.py @@ -1,4 +1,5 @@ import ast +import sys from typing import Any, List, Tuple @@ -9,15 +10,14 @@ class Unknown(ast.AST): Unknown = Unknown() # create a single instance of value node +if sys.version_info.minor < 9: + import astunparse -def unparse(*args, **kwargs): - import sys - - if sys.version_info.minor < 9: - import astunparse - + def unparse(*args, **kwargs): return astunparse.unparse(*args, **kwargs) - else: +else: + + def unparse(*args, **kwargs): return ast.unparse(*args, **kwargs) @@ -26,12 +26,15 @@ def __init__(self, *, var_names: List[str]): self._var_names = var_names def transform(self, expr: str) -> Tuple[str, bool]: - node = ast.parse(expr, mode='eval') - transformed = self.visit(node) - if transformed is Unknown: - return expr, False + if expr: + node = ast.parse(expr, mode='eval') + transformed = self.visit(node) + if transformed is Unknown: + return expr, False + else: + return unparse(transformed), True else: - return unparse(transformed), True + return expr, False def visit_Expression(self, node: ast.Expression) -> Any: node: 
ast.Expression = self.generic_visit(node) From 86deb77b35928757da2e9f8b2f7246c0451c3beb Mon Sep 17 00:00:00 2001 From: mihran113 Date: Thu, 13 Mar 2025 15:40:04 +0400 Subject: [PATCH 14/30] [chore] Remove legacy (`aim 2.x.x`) sdk (#3305) --- aim/sdk/__init__.py | 7 -- aim/sdk/legacy/__init__.py | 0 aim/sdk/legacy/deprecation_warning.py | 15 --- aim/sdk/legacy/flush.py | 6 -- aim/sdk/legacy/init.py | 7 -- aim/sdk/legacy/select.py | 30 ------ aim/sdk/legacy/session/__init__.py | 1 - aim/sdk/legacy/session/configs.py | 1 - aim/sdk/legacy/session/session.py | 132 -------------------------- aim/sdk/legacy/track.py | 16 ---- 10 files changed, 215 deletions(-) delete mode 100644 aim/sdk/legacy/__init__.py delete mode 100644 aim/sdk/legacy/deprecation_warning.py delete mode 100644 aim/sdk/legacy/flush.py delete mode 100644 aim/sdk/legacy/init.py delete mode 100644 aim/sdk/legacy/select.py delete mode 100644 aim/sdk/legacy/session/__init__.py delete mode 100644 aim/sdk/legacy/session/configs.py delete mode 100644 aim/sdk/legacy/session/session.py delete mode 100644 aim/sdk/legacy/track.py diff --git a/aim/sdk/__init__.py b/aim/sdk/__init__.py index 17d6974a6..f7c190da1 100644 --- a/aim/sdk/__init__.py +++ b/aim/sdk/__init__.py @@ -1,10 +1,3 @@ -# Legacy SDK functions -from aim.sdk.legacy.flush import flush -from aim.sdk.legacy.init import init -from aim.sdk.legacy.select import select_metrics, select_runs -from aim.sdk.legacy.session import Session -from aim.sdk.legacy.track import set_params, track - # pre-defined sequences and custom objects from aim.sdk.objects import Audio, Distribution, Figure, Image, Text from aim.sdk.repo import Repo diff --git a/aim/sdk/legacy/__init__.py b/aim/sdk/legacy/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/aim/sdk/legacy/deprecation_warning.py b/aim/sdk/legacy/deprecation_warning.py deleted file mode 100644 index 36047509e..000000000 --- a/aim/sdk/legacy/deprecation_warning.py +++ /dev/null @@ -1,15 +0,0 @@ -import logging - -from functools import wraps - - -logger = logging.getLogger(__name__) - - -def deprecated(func): - @wraps(func) - def wrapper(*args, **kwargs): - logger.warning(msg=f'Usage of {func.__qualname__} is deprecated!') - return func(*args, **kwargs) - - return wrapper diff --git a/aim/sdk/legacy/flush.py b/aim/sdk/legacy/flush.py deleted file mode 100644 index 74dc48c2a..000000000 --- a/aim/sdk/legacy/flush.py +++ /dev/null @@ -1,6 +0,0 @@ -from aim.sdk.legacy.deprecation_warning import deprecated - - -@deprecated -def flush(): - pass diff --git a/aim/sdk/legacy/init.py b/aim/sdk/legacy/init.py deleted file mode 100644 index 6456e4950..000000000 --- a/aim/sdk/legacy/init.py +++ /dev/null @@ -1,7 +0,0 @@ -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.legacy.session import DefaultSession - - -@deprecated -def init(*args, **kwargs): - DefaultSession(*args, **kwargs) diff --git a/aim/sdk/legacy/select.py b/aim/sdk/legacy/select.py deleted file mode 100644 index 4b637d13f..000000000 --- a/aim/sdk/legacy/select.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional - -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.repo import Repo - - -@deprecated -def select_metrics(search_statement: str, repo_path: Optional[str] = None): - if repo_path is not None: - repo = Repo.from_path(repo_path) - else: - repo = Repo.default_repo() - - if not repo: - return None - - return repo.query_metrics(search_statement) - - -@deprecated -def select_runs(expression: Optional[str] = None, 
repo_path: Optional[str] = None): - if repo_path is not None: - repo = Repo.from_path(repo_path) - else: - repo = Repo.default_repo() - - if not repo: - return None - - return repo.query_runs(expression) diff --git a/aim/sdk/legacy/session/__init__.py b/aim/sdk/legacy/session/__init__.py deleted file mode 100644 index 6c268677a..000000000 --- a/aim/sdk/legacy/session/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from aim.sdk.legacy.session.session import DefaultSession, Session diff --git a/aim/sdk/legacy/session/configs.py b/aim/sdk/legacy/session/configs.py deleted file mode 100644 index 08db0117c..000000000 --- a/aim/sdk/legacy/session/configs.py +++ /dev/null @@ -1 +0,0 @@ -DEFAULT_FLUSH_FREQUENCY = 128 diff --git a/aim/sdk/legacy/session/session.py b/aim/sdk/legacy/session/session.py deleted file mode 100644 index cc77b9744..000000000 --- a/aim/sdk/legacy/session/session.py +++ /dev/null @@ -1,132 +0,0 @@ -import atexit -import os -import signal -import threading - -from typing import Optional - -from aim.ext.exception_resistant import exception_resistant -from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.repo import Repo -from aim.sdk.run import Run - - -class Session: - sessions = {} - - _are_exit_listeners_set = False - _original_sigint_handler = None - _original_sigterm_handler = None - - @deprecated - def __init__( - self, - repo: Optional[str] = None, - experiment: Optional[str] = None, - flush_frequency: int = 0, # unused - block_termination: bool = True, # unused - run: Optional[str] = None, - system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT, - ): - self._repo = Repo.from_path(repo) if repo else Repo.default_repo() - self._repo_path = self._repo.path - self._run = Run(run, repo=self._repo, experiment=experiment, system_tracking_interval=system_tracking_interval) - self._run_hash = self._run.hash - self.active = True - - Session.sessions.setdefault(self._repo_path, []) - Session.sessions[self._repo_path].append(self) - - # Bind signal listeners - self._set_exit_handlers() - - @property - def run_hash(self): - return self._run_hash - - @property - def repo_path(self): - return self._repo_path - - @exception_resistant(silent=False) - def track(self, *args, **kwargs): - val = args[0] - name = kwargs.pop('name') - step = kwargs.pop('step', None) - epoch = kwargs.pop('epoch', None) - for key in kwargs.keys(): - if key.startswith('__'): - del kwargs[key] - - self._run.track(val, name=name, step=step, epoch=epoch, context=kwargs) - - @exception_resistant(silent=False) - def set_params(self, params: dict, name: Optional[str] = None): - if name is None: - self._run[...] 
= params - else: - self._run[name] = params - - def flush(self): - pass - - @exception_resistant(silent=False) - def close(self): - if not self.active: - raise Exception('session is closed') - if self._run: - del self._run - self._run = None - if self._repo_path in Session.sessions and self in Session.sessions[self._repo_path]: - Session.sessions[self._repo_path].remove(self) - if len(Session.sessions[self._repo_path]) == 0: - del Session.sessions[self._repo_path] - self.active = False - - @classmethod - def _close_sessions(cls, *args, **kwargs): - threads = [] - for _, sessions in cls.sessions.items(): - for session in sessions: - th = threading.Thread(target=session.close) - th.daemon = True - threads.append(th) - - for th in threads: - th.start() - - for th in threads: - th.join() - - if len(args): - if args[0] == 15: - signal.signal(signal.SIGTERM, cls._original_sigterm_handler) - os.kill(os.getpid(), 15) - # elif args[0] == 2: - # signal.signal(signal.SIGINT, cls._original_sigint_handler) - # os.kill(os.getpid(), 2) - - @classmethod - def _set_exit_handlers(cls): - if not cls._are_exit_listeners_set: - cls._are_exit_listeners_set = True - # cls._original_sigint_handler = signal.getsignal(signal.SIGINT) - cls._original_sigterm_handler = signal.getsignal(signal.SIGTERM) - - atexit.register(cls._close_sessions) - # signal.signal(signal.SIGINT, cls._close_sessions) - signal.signal(signal.SIGTERM, cls._close_sessions) - - -DefaultSession = Session - - -def get_default_session() -> Session: - if len(Session.sessions.keys()) > 0: - default_sess_key = list(Session.sessions.keys())[0] - if len(Session.sessions[default_sess_key]) > 0: - return Session.sessions[default_sess_key][0] - - # Create and return default session otherwise - return DefaultSession() diff --git a/aim/sdk/legacy/track.py b/aim/sdk/legacy/track.py deleted file mode 100644 index ef04f2bcc..000000000 --- a/aim/sdk/legacy/track.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Optional - -from aim.sdk.legacy.deprecation_warning import deprecated -from aim.sdk.legacy.session.session import get_default_session - - -@deprecated -def track(*args, **kwargs): - sess = get_default_session() - return sess.track(*args, **kwargs) - - -@deprecated -def set_params(params: dict, name: Optional[str] = None): - sess = get_default_session() - return sess.set_params(params, name) From 795067ce475a67fd62ef360efc69515477dbf8d8 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Thu, 13 Mar 2025 15:42:47 +0400 Subject: [PATCH 15/30] [fix] Improve error messages for remote tracking (#3303) --- CHANGELOG.md | 3 ++- aim/ext/transport/message_utils.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dcc243e4..9c267f09e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,12 @@ - Fix connection of data points on epoch alignment (mihran113) - Resolve issue with adding duplicate tags to the same run (mihran113) - Resolve session refresh issues when db file is replaced (mihran113) +- Improve error messages for remote tracking server (mihran113) +- Fix spurious assertion error in message stream parsing (qzed) ### Enhancements: - Skip metrics check when run is known to yield false result (alberttorosyan) - Correct indentation on query proxy object return statement (alberttorosyan) -- Fix spurious assertion error in message stream parsing (qzed) ## 3.27.0 Dec 18, 2024 diff --git a/aim/ext/transport/message_utils.py b/aim/ext/transport/message_utils.py index cc0422e5a..20f9c717f 100644 --- 
a/aim/ext/transport/message_utils.py +++ b/aim/ext/transport/message_utils.py @@ -55,7 +55,9 @@ def raise_exception(server_exception): module = importlib.import_module(server_exception.get('module_name')) exception = getattr(module, server_exception.get('class_name')) args = json.loads(server_exception.get('args') or []) - raise exception(*args) if args else exception() + message = server_exception.get('message') + + raise exception(*args) if args else Exception(message) def build_exception(exception: Exception): @@ -63,6 +65,7 @@ def build_exception(exception: Exception): 'module_name': exception.__class__.__module__, 'class_name': exception.__class__.__name__, 'args': json.dumps(exception.args), + 'message': str(exception), } From 5bafebb91aaed9ec541a8774a8105329d16df59b Mon Sep 17 00:00:00 2001 From: Vassilis Vassiliadis <43679502+VassilisVassiliadis@users.noreply.github.com> Date: Thu, 13 Mar 2025 11:54:28 +0000 Subject: [PATCH 16/30] [feat] Add AimCallback for distributed runs using the hugging face API (#3284) There is a singular aim.Run which the main worker initializes and manages. All auxiliary workers (local_rank 0 workers hosted on other nodes) collect their metrics and forward them to the main worker. The main worker records the metrics in AIM. Signed-off-by: Vassilis Vassiliadis --- aim/distributed_hugging_face.py | 2 + aim/sdk/adapters/distributed_hugging_face.py | 498 +++++++++++++++++++ 2 files changed, 500 insertions(+) create mode 100644 aim/distributed_hugging_face.py create mode 100644 aim/sdk/adapters/distributed_hugging_face.py diff --git a/aim/distributed_hugging_face.py b/aim/distributed_hugging_face.py new file mode 100644 index 000000000..0fa836d02 --- /dev/null +++ b/aim/distributed_hugging_face.py @@ -0,0 +1,2 @@ +# Alias to SDK distributed hugging face interface +from aim.sdk.adapters.distributed_hugging_face import AimCallback # noqa F401 diff --git a/aim/sdk/adapters/distributed_hugging_face.py b/aim/sdk/adapters/distributed_hugging_face.py new file mode 100644 index 000000000..76e5e6937 --- /dev/null +++ b/aim/sdk/adapters/distributed_hugging_face.py @@ -0,0 +1,498 @@ +import os + +from aim.ext.resource.stat import Stat + +try: + import accelerate.utils.environment +except ImportError: + raise RuntimeError( + "This contrib module requires HuggingFace Accelerate to be installed. 
" + "Please install it with command: \n pip install accelerate" + ) + +import copy +import struct +import threading +import time + +import aim +import aim.hugging_face +import aim.ext.resource +import aim.sdk.configs + +import typing +import socket +import select +import logging +import json + + +class IncompletePackageError(Exception): + pass + + +class IncompleteHeaderError(IncompletePackageError): + pass + + +class IncompleteDataError(IncompletePackageError): + pass + + +def packet_encode(usage: typing.Dict[str, typing.Any]) -> bytes: + data = json.dumps(usage) + header = len(data).to_bytes(4, "big") + packet = b"".join((header, struct.pack(f"!{len(data)}s", data.encode("utf-8")))) + return packet + + +def packet_decode(packet: bytes) -> typing.Dict[str, typing.Any]: + length = int.from_bytes(packet[:4], "big") + raw = struct.unpack_from(f"!{length}s", packet, 4)[0] + decoded = json.loads(raw) + return decoded + + +class ResourceTrackerForwarder(aim.ext.resource.ResourceTracker): + def _track(self, stat: Stat): + # Instead of tracking individual system metrics, forward the entire update to the MetricsReporter + # in turn, the MetricsReporter will create a packet ouf of Stat and its context (rank, world_size, etc). + # Next, it'll send that packet to the MetricsReceiver which will then push the data to the Aim server + self._tracker()(stat) + + +class MetricsReporter: + def __init__( + self, + host: str, + port: int, + node_rank: int, + rank: int, + interval: typing.Union[int, float], + ): + self.client: typing.Optional[socket.socket] = None + + self.node_rank = node_rank + self.rank = rank + self.log = logging.getLogger(f"MetricsReporter{rank}") + + self._connect(host=host, port=port) + self.tracker = ResourceTrackerForwarder( + tracker=self, interval=interval, capture_logs=False + ) + + def start(self): + self.tracker.start() + + def stop(self): + if self.tracker._shutdown is False: + self.tracker.stop() + if self.client is not None: + self.client.close() + self.client = None + + def _connect( + self, + host: str, + port: int, + connection_timeout: int = 60 * 10, + retry_seconds: int = 5, + ): + start = time.time() + + while time.time() - start <= connection_timeout: + # This should deal with both ipv4 and ipv6 hosts + for family, socktype, proto, canonname, sa in socket.getaddrinfo( + host, port, proto=socket.SOL_TCP + ): + self.client = socket.socket(family, socktype, proto) + try: + self.client.connect(sa) + return + except (ConnectionRefusedError, OSError) as e: + self.client.close() + self.log.info( + f"Could not connect to main worker due to {e} - " + f"will retry in {retry_seconds} seconds" + ) + time.sleep(retry_seconds) + + raise ConnectionError( + f"Could not connect to server {host}:{port} after {connection_timeout} seconds" + ) + + def __call__(self, stat: aim.ext.resource.tracker.Stat): + if self.client is None: + self.log.info( + "Connection has already closed, will not propagate this system metrics snapshot" + ) + return + + # This is invoked by @self.tracker + raw = { + "stat": stat.stat_item.to_dict(), + "worker": { + "node_rank": self.node_rank, + "rank": self.rank, + }, + } + self.log.debug(f"Send {raw}") + + packet = packet_encode(raw) + try: + self.client.sendall(packet) + except BrokenPipeError: + self.log.info( + f"BrokenPipeError while transmitting system metrics {raw} - will stop recording system metrics" + ) + try: + self.stop() + except RuntimeError as e: + if e.args[0] != "cannot join current thread": + # Calling stop() causes self.tracker() to try to join 
this thread. In turn that raises + # this RuntimeError + raise + except Exception as e: + self.log.info( + f"{e} while transmitting system metrics {raw} - will ignore exception" + ) + + +class MetricsReceiver: + def __init__( + self, + host: str, + port: int, + num_workers: int, + connection_timeout: int, + ): + self.tracker: typing.Optional[ + typing.Callable[ + [ + typing.Dict[str, typing.Any], + typing.Dict[str, typing.Any], + ], + None, + ] + ] = None + + self.clients: typing.List[socket.socket] = [] + self.log = logging.getLogger("MetricsReceiver") + self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + self._wait_workers( + host=host, + port=port, + num_workers=num_workers, + connection_timeout=connection_timeout, + ) + + self.running = True + self.thread: typing.Optional[threading.Thread] = None + + def start( + self, + tracker: typing.Callable[ + [ + typing.Dict[str, typing.Any], + typing.Dict[str, typing.Any], + ], + None, + ], + ): + self.tracker = tracker + self.running = True + + self.thread = threading.Thread(target=self._collect_metrics, daemon=True) + self.thread.start() + + def stop(self): + if self.running: + self.running = False + self.thread.join() + + def _recv(self, sock: socket.socket, length: int) -> typing.Optional[bytes]: + data = b"" + retries = 0 + while len(data) < length and retries < 10: + buf = sock.recv(length - len(data)) + data += buf + retries += 1 + + if len(data) < length: + if len(data) > 0: + raise IncompletePackageError() + # If recv() returned b'' then the client disconnected + return None + + return data + + def _recv_packet( + self, sock: socket.socket + ) -> typing.Optional[typing.Dict[str, typing.Any]]: + try: + header = self._recv(sock, 4) + + if header is None: + # The client disconnected + return None + + length = int.from_bytes(header, "big") + except IncompletePackageError: + raise IncompleteHeaderError() + try: + data = self._recv(sock, length) + if len(data) > 0: + return json.loads(data) + except IncompletePackageError: + raise IncompleteDataError() + + def _collect_metrics(self): + while self.running: + read, _write, _error = select.select(self.clients, [], [], 5.0) + for client in typing.cast(typing.List[socket.socket], read): + try: + packet = self._recv_packet(client) + except IncompletePackageError as e: + self.log.info( + f"Error {e} while receiving update - will assume this is a transient error" + ) + continue + + if packet: + self.tracker(packet["stat"], packet["worker"]) + else: + self.log.info("Client disconnected") + client.close() + self.clients.remove(client) + + def _wait_workers( + self, host: str, port: int, num_workers: int, connection_timeout: float + ): + # This may raise an exception, don't catch it here and let it flow to the caller + self.server.bind((host, port)) + self.server.listen(num_workers) + + # We're actually going to pause here, till we get all clients OR we run out of time + + start = time.time() + self.log.info(f"Waiting for {num_workers} workers to connect") + # Block for 5 seconds while waiting for new connections + while time.time() - start <= connection_timeout: + read, _write, _error = select.select([self.server], [], [], 5.0) + for server in read: + client, _client_address = server.accept() + self.clients.append(client) + self.log.info(f"Client {len(self.clients)}/{num_workers} connected") + + if len(self.clients) == num_workers: + return + + self.server.close() + + raise ConnectionError( + f"{num_workers - len(self.clients)} out of {num_workers} total clients did not connect" + ) + + 
+class AimCallback(aim.hugging_face.AimCallback): + def __init__( + self, + main_port: int, + repo: typing.Optional[str] = None, + experiment: typing.Optional[str] = None, + system_tracking_interval: typing.Optional[ + int + ] = aim.ext.resource.DEFAULT_SYSTEM_TRACKING_INT, + log_system_params: typing.Optional[bool] = True, + capture_terminal_logs: typing.Optional[bool] = True, + main_addr: typing.Optional[str] = None, + distributed_information: typing.Optional[ + accelerate.utils.environment.CPUInformation + ] = None, + connection_timeout: int = 60 * 5, + workers_only_on_rank_0: bool = True, + ): + """A HuggingFace TrainerCallback which registers the system metrics of all workers involved in the training + under a single Aim run. + + This code initializes aim.hugging_face.AimCallback() only on rank 0 - otherwise we'd end up with multiple + Aim runs. + + Args: + main_port (:obj:`int`): Configures the port that the main worker will listen on. If this is None + then the code will raise an exception. + repo (:obj:`Union[Repo,str]`, optional): Aim repository path or Repo object to which Run object is bound. + If skipped, default Repo is used. + experiment (:obj:`str`, optional): Sets Run's `experiment` property. 'default' if not specified. + Can be used later to query runs/sequences. + system_tracking_interval (:obj:`int`, optional): Sets the tracking interval in seconds for system usage + metrics (CPU, Memory, etc.). Set to `None` to disable system metrics tracking. + log_system_params (:obj:`bool`, optional): Enable/Disable logging of system params such as installed + packages, git info, environment variables, etc. + main_addr (:obj:`str`, optional): The address of the main worker. If this is None then the code will + auto-discover it from the environment variable MASTER_ADDR. If this parameter cannot be resolved + to a non-empty value the method will raise an exception. + distributed_information (:obj:`str`, accelerate.utils.environment.CPUInformation): information about the + CPU in a distributed environment. If None, the code parses environment variables to auto create it. + See accelerate.utils.get_cpu_distributed_information() for more details + connection_timeout (:obj:`int`, optional): Maximum seconds to wait for the auxiliary workers to connect. + workers_only_on_rank_0 (:obj:`bool`): When set to true, only treat processes with local_rank 0 as + workers. Setting this to False, only makes sense when debugging the AimCallback() code. + + Raises: + ConnectionError: + If unable auxiliary workers are unable to connect to main worker + """ + if main_addr is None: + main_addr = os.environ.get("MASTER_ADDR") + + if not main_addr: + raise ValueError("main_addr cannot be empty") + + if not main_port or main_port <0: + raise ValueError("main_port must be a positive number") + + if distributed_information is None: + distributed_information = accelerate.utils.get_cpu_distributed_information() + + self.distributed_information = distributed_information + self.connection_timeout = connection_timeout + + self.listening_socket: typing.Optional[socket.socket] = None + + self.metrics_reporter: typing.Optional[MetricsReporter] = None + self.metrics_receiver: typing.Optional[MetricsReceiver] = None + + self._run: typing.Optional[aim.Run] = None + self.log = logging.getLogger("CustomAimCallback") + + if not workers_only_on_rank_0: + # This is primarily for debugging. 
It enables the creation of multiple auxiliary workers on a single node + auxiliary_workers = self.distributed_information.world_size + else: + auxiliary_workers = ( + self.distributed_information.world_size + // self.distributed_information.local_world_size + ) + + # Instantiate a MetricsReporter for all workers which are not rank 0 + if ( + self.distributed_information.rank is not None + and self.distributed_information.rank > 0 + and ( + not workers_only_on_rank_0 + or self.distributed_information.local_rank == 0 + ) + and system_tracking_interval is not None + ): + if workers_only_on_rank_0: + node_rank = ( + distributed_information.rank + // distributed_information.local_world_size + ) + else: + node_rank = distributed_information.rank + + self.metrics_reporter = MetricsReporter( + host=main_addr, + port=main_port, + rank=self.distributed_information.rank, + node_rank=node_rank, + interval=system_tracking_interval, + ) + + self.log.info( + f"Distributed worker {self.distributed_information.rank} connected" + ) + elif self.distributed_information.rank == 0: + # When running as the main worker, we initialize aim as usual. If there're multiple + # auxiliary workers, we also start a listening server. The auxiliary workers will connect + # to this server and periodically send over their system metrics + super().__init__( + repo, + experiment, + system_tracking_interval, + log_system_params, + capture_terminal_logs, + ) + + if ( + auxiliary_workers > 1 + and main_port is not None + and system_tracking_interval is not None + ): + self.log.info(f"There are {auxiliary_workers} workers") + + self.metrics_receiver = MetricsReceiver( + # Bind to 0.0.0.0 so that we can accept connections coming in from any interface + host="0.0.0.0", + port=main_port, + num_workers=auxiliary_workers - 1, + connection_timeout=self.connection_timeout, + ) + + self.metrics_receiver.start(self._push_auxiliary_worker_metrics) + + def _push_auxiliary_worker_metrics( + self, + stat: typing.Dict[str, typing.Any], + worker_info: typing.Dict[str, typing.Any], + ): + """Utility method which pushes the system metrics of an auxiliary worker to Aim + + Args: + stat: (:obj:`typing.Dict[str, typing.Any]`): A dictionary representation of + aim.ext.resource.stat.Stat + worker_info (:obj:`typing.Dict[str, typing.Any]`): A dictionary which represents the context of a + worker. 
For example, it can contain the fields {"rank": int, "node_rank": int} + """ + # TODO: Investigate whether it's better to spin up a dedicated RunTracker here or not + if self._run is None: + self.log.info( + f"The aim Run is inactive, will not register these metrics from {worker_info}" + ) + return + + tracker = self._run._tracker + context = copy.deepcopy(worker_info) + + for resource, usage in stat["system"].items(): + tracker( + usage, + name="{}{}".format( + aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, + resource, + ), + context=context, + ) + + # Store GPU stats + for gpu_idx, gpu in enumerate(stat["gpus"]): + for resource, usage in gpu.items(): + context = copy.deepcopy(worker_info) + context.update({"gpu": gpu_idx}) + + tracker( + usage, + name="{}{}".format( + aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, + resource, + ), + context=context, + ) + + def on_train_begin(self, args, state, control, model=None, **kwargs): + super().on_train_begin(args, state, control, model, **kwargs) + + if self.metrics_reporter: + self.metrics_reporter.start() + + def close(self): + try: + super().close() + finally: + if self.metrics_receiver is not None: + self.metrics_receiver.stop() + if self.metrics_reporter is not None: + self.metrics_reporter.stop() From c57f4a899672ab76f36b646ddecc64146a722b40 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Fri, 14 Mar 2025 18:16:39 +0400 Subject: [PATCH 17/30] [fix] Increase session pool size for sqlite engine (#3306) --- CHANGELOG.md | 1 - aim/storage/structured/db.py | 9 ++++++--- aim/web/api/db.py | 10 ++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c267f09e..f44412ac9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,6 @@ - Decrease client resources keep-alive time (mihran113) - Fix connection of data points on epoch alignment (mihran113) - Resolve issue with adding duplicate tags to the same run (mihran113) -- Resolve session refresh issues when db file is replaced (mihran113) - Improve error messages for remote tracking server (mihran113) - Fix spurious assertion error in message stream parsing (qzed) diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index 2849b3be4..830c0bc41 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -61,8 +61,12 @@ def __init__(self, path: str, readonly: bool = False): self.db_url = self.get_db_url(path) self.readonly = readonly self.engine = create_engine( - self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))) + self.db_url, + echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), + pool_size=10, + max_overflow=20, ) + self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) self._upgraded = None @classmethod @@ -90,8 +94,7 @@ def caches(self): return self._caches def get_session(self, autocommit=True): - session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) - session = session_cls() + session = self.session_cls() setattr(session, 'autocommit', autocommit) return session diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 80aeaa539..c38e2598f 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -14,13 +14,16 @@ get_db_url(), echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), connect_args={'check_same_thread': False}, + pool_size=10, + max_overflow=20, ) + +SessionLocal = sessionmaker(autoflush=False, bind=engine) Base = declarative_base() def 
get_session(): - session_cls = sessionmaker(autoflush=False, bind=engine) - session = session_cls() + session = SessionLocal() try: yield session finally: @@ -29,8 +32,7 @@ def get_session(): @contextmanager def get_contexted_session(): - session_cls = sessionmaker(autoflush=False, bind=engine) - session = session_cls() + session = SessionLocal() try: yield session finally: From f731d3e4597b1f15296ceba2df3fd60d11d3e5b5 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Tue, 18 Mar 2025 14:43:59 +0400 Subject: [PATCH 18/30] [feat] Remove metric version check to improve metric retrieval performance (#3307) --- CHANGELOG.md | 1 + aim/cli/up/commands.py | 2 +- aim/distributed_hugging_face.py | 2 +- aim/sdk/adapters/distributed_hugging_face.py | 156 +++++++----------- aim/sdk/run.py | 8 - aim/storage/encoding/encoding.pyx | 2 +- .../RunOverviewTab/RunOverviewTab.tsx | 7 - 7 files changed, 62 insertions(+), 116 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f44412ac9..0436f92e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ ### Enhancements: - Skip metrics check when run is known to yield false result (alberttorosyan) - Correct indentation on query proxy object return statement (alberttorosyan) +- Remove metric version check to improve performance of metric retrieval (mihran113) ## 3.27.0 Dec 18, 2024 diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index de8ed008d..aafd73b2b 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -29,7 +29,7 @@ @click.command() @click.option('-h', '--host', default=AIM_UI_DEFAULT_HOST, type=str) @click.option('-p', '--port', default=AIM_UI_DEFAULT_PORT, type=int) -@click.option('-w', '--workers', default=1, type=int) +@click.option('-w', '--workers', default=2, type=int) @click.option('--uds', required=False, type=click.Path(exists=False, file_okay=True, dir_okay=False, readable=True)) @click.option('--repo', required=False, type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True)) @click.option('--tf_logs', type=click.Path(exists=True, readable=True)) diff --git a/aim/distributed_hugging_face.py b/aim/distributed_hugging_face.py index 0fa836d02..cbb9f8eec 100644 --- a/aim/distributed_hugging_face.py +++ b/aim/distributed_hugging_face.py @@ -1,2 +1,2 @@ # Alias to SDK distributed hugging face interface -from aim.sdk.adapters.distributed_hugging_face import AimCallback # noqa F401 +from aim.sdk.adapters.distributed_hugging_face import AimCallback # noqa: F401 diff --git a/aim/sdk/adapters/distributed_hugging_face.py b/aim/sdk/adapters/distributed_hugging_face.py index 76e5e6937..561bef82d 100644 --- a/aim/sdk/adapters/distributed_hugging_face.py +++ b/aim/sdk/adapters/distributed_hugging_face.py @@ -2,30 +2,30 @@ from aim.ext.resource.stat import Stat + try: import accelerate.utils.environment except ImportError: raise RuntimeError( - "This contrib module requires HuggingFace Accelerate to be installed. " - "Please install it with command: \n pip install accelerate" + 'This contrib module requires HuggingFace Accelerate to be installed. 
' + 'Please install it with command: \n pip install accelerate' ) import copy +import json +import logging +import select +import socket import struct import threading import time +import typing import aim -import aim.hugging_face import aim.ext.resource +import aim.hugging_face import aim.sdk.configs -import typing -import socket -import select -import logging -import json - class IncompletePackageError(Exception): pass @@ -41,14 +41,14 @@ class IncompleteDataError(IncompletePackageError): def packet_encode(usage: typing.Dict[str, typing.Any]) -> bytes: data = json.dumps(usage) - header = len(data).to_bytes(4, "big") - packet = b"".join((header, struct.pack(f"!{len(data)}s", data.encode("utf-8")))) + header = len(data).to_bytes(4, 'big') + packet = b''.join((header, struct.pack(f'!{len(data)}s', data.encode('utf-8')))) return packet def packet_decode(packet: bytes) -> typing.Dict[str, typing.Any]: - length = int.from_bytes(packet[:4], "big") - raw = struct.unpack_from(f"!{length}s", packet, 4)[0] + length = int.from_bytes(packet[:4], 'big') + raw = struct.unpack_from(f'!{length}s', packet, 4)[0] decoded = json.loads(raw) return decoded @@ -74,12 +74,10 @@ def __init__( self.node_rank = node_rank self.rank = rank - self.log = logging.getLogger(f"MetricsReporter{rank}") + self.log = logging.getLogger(f'MetricsReporter{rank}') self._connect(host=host, port=port) - self.tracker = ResourceTrackerForwarder( - tracker=self, interval=interval, capture_logs=False - ) + self.tracker = ResourceTrackerForwarder(tracker=self, interval=interval, capture_logs=False) def start(self): self.tracker.start() @@ -102,9 +100,7 @@ def _connect( while time.time() - start <= connection_timeout: # This should deal with both ipv4 and ipv6 hosts - for family, socktype, proto, canonname, sa in socket.getaddrinfo( - host, port, proto=socket.SOL_TCP - ): + for family, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, proto=socket.SOL_TCP): self.client = socket.socket(family, socktype, proto) try: self.client.connect(sa) @@ -112,50 +108,43 @@ def _connect( except (ConnectionRefusedError, OSError) as e: self.client.close() self.log.info( - f"Could not connect to main worker due to {e} - " - f"will retry in {retry_seconds} seconds" + f'Could not connect to main worker due to {e} - will retry in {retry_seconds} seconds' ) time.sleep(retry_seconds) - raise ConnectionError( - f"Could not connect to server {host}:{port} after {connection_timeout} seconds" - ) + raise ConnectionError(f'Could not connect to server {host}:{port} after {connection_timeout} seconds') def __call__(self, stat: aim.ext.resource.tracker.Stat): if self.client is None: - self.log.info( - "Connection has already closed, will not propagate this system metrics snapshot" - ) + self.log.info('Connection has already closed, will not propagate this system metrics snapshot') return # This is invoked by @self.tracker raw = { - "stat": stat.stat_item.to_dict(), - "worker": { - "node_rank": self.node_rank, - "rank": self.rank, + 'stat': stat.stat_item.to_dict(), + 'worker': { + 'node_rank': self.node_rank, + 'rank': self.rank, }, } - self.log.debug(f"Send {raw}") + self.log.debug(f'Send {raw}') packet = packet_encode(raw) try: self.client.sendall(packet) except BrokenPipeError: self.log.info( - f"BrokenPipeError while transmitting system metrics {raw} - will stop recording system metrics" + f'BrokenPipeError while transmitting system metrics {raw} - will stop recording system metrics' ) try: self.stop() except RuntimeError as e: - if e.args[0] != 
"cannot join current thread": + if e.args[0] != 'cannot join current thread': # Calling stop() causes self.tracker() to try to join this thread. In turn that raises # this RuntimeError raise except Exception as e: - self.log.info( - f"{e} while transmitting system metrics {raw} - will ignore exception" - ) + self.log.info(f'{e} while transmitting system metrics {raw} - will ignore exception') class MetricsReceiver: @@ -177,7 +166,7 @@ def __init__( ] = None self.clients: typing.List[socket.socket] = [] - self.log = logging.getLogger("MetricsReceiver") + self.log = logging.getLogger('MetricsReceiver') self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._wait_workers( @@ -212,7 +201,7 @@ def stop(self): self.thread.join() def _recv(self, sock: socket.socket, length: int) -> typing.Optional[bytes]: - data = b"" + data = b'' retries = 0 while len(data) < length and retries < 10: buf = sock.recv(length - len(data)) @@ -227,9 +216,7 @@ def _recv(self, sock: socket.socket, length: int) -> typing.Optional[bytes]: return data - def _recv_packet( - self, sock: socket.socket - ) -> typing.Optional[typing.Dict[str, typing.Any]]: + def _recv_packet(self, sock: socket.socket) -> typing.Optional[typing.Dict[str, typing.Any]]: try: header = self._recv(sock, 4) @@ -237,7 +224,7 @@ def _recv_packet( # The client disconnected return None - length = int.from_bytes(header, "big") + length = int.from_bytes(header, 'big') except IncompletePackageError: raise IncompleteHeaderError() try: @@ -254,21 +241,17 @@ def _collect_metrics(self): try: packet = self._recv_packet(client) except IncompletePackageError as e: - self.log.info( - f"Error {e} while receiving update - will assume this is a transient error" - ) + self.log.info(f'Error {e} while receiving update - will assume this is a transient error') continue if packet: - self.tracker(packet["stat"], packet["worker"]) + self.tracker(packet['stat'], packet['worker']) else: - self.log.info("Client disconnected") + self.log.info('Client disconnected') client.close() self.clients.remove(client) - def _wait_workers( - self, host: str, port: int, num_workers: int, connection_timeout: float - ): + def _wait_workers(self, host: str, port: int, num_workers: int, connection_timeout: float): # This may raise an exception, don't catch it here and let it flow to the caller self.server.bind((host, port)) self.server.listen(num_workers) @@ -276,23 +259,21 @@ def _wait_workers( # We're actually going to pause here, till we get all clients OR we run out of time start = time.time() - self.log.info(f"Waiting for {num_workers} workers to connect") + self.log.info(f'Waiting for {num_workers} workers to connect') # Block for 5 seconds while waiting for new connections while time.time() - start <= connection_timeout: read, _write, _error = select.select([self.server], [], [], 5.0) for server in read: client, _client_address = server.accept() self.clients.append(client) - self.log.info(f"Client {len(self.clients)}/{num_workers} connected") + self.log.info(f'Client {len(self.clients)}/{num_workers} connected') if len(self.clients) == num_workers: return self.server.close() - raise ConnectionError( - f"{num_workers - len(self.clients)} out of {num_workers} total clients did not connect" - ) + raise ConnectionError(f'{num_workers - len(self.clients)} out of {num_workers} total clients did not connect') class AimCallback(aim.hugging_face.AimCallback): @@ -301,15 +282,11 @@ def __init__( main_port: int, repo: typing.Optional[str] = None, experiment: typing.Optional[str] = 
None, - system_tracking_interval: typing.Optional[ - int - ] = aim.ext.resource.DEFAULT_SYSTEM_TRACKING_INT, + system_tracking_interval: typing.Optional[int] = aim.ext.resource.DEFAULT_SYSTEM_TRACKING_INT, log_system_params: typing.Optional[bool] = True, capture_terminal_logs: typing.Optional[bool] = True, main_addr: typing.Optional[str] = None, - distributed_information: typing.Optional[ - accelerate.utils.environment.CPUInformation - ] = None, + distributed_information: typing.Optional[accelerate.utils.environment.CPUInformation] = None, connection_timeout: int = 60 * 5, workers_only_on_rank_0: bool = True, ): @@ -345,13 +322,13 @@ def __init__( If unable auxiliary workers are unable to connect to main worker """ if main_addr is None: - main_addr = os.environ.get("MASTER_ADDR") + main_addr = os.environ.get('MASTER_ADDR') if not main_addr: - raise ValueError("main_addr cannot be empty") + raise ValueError('main_addr cannot be empty') - if not main_port or main_port <0: - raise ValueError("main_port must be a positive number") + if not main_port or main_port < 0: + raise ValueError('main_port must be a positive number') if distributed_information is None: distributed_information = accelerate.utils.get_cpu_distributed_information() @@ -365,32 +342,23 @@ def __init__( self.metrics_receiver: typing.Optional[MetricsReceiver] = None self._run: typing.Optional[aim.Run] = None - self.log = logging.getLogger("CustomAimCallback") + self.log = logging.getLogger('CustomAimCallback') if not workers_only_on_rank_0: # This is primarily for debugging. It enables the creation of multiple auxiliary workers on a single node auxiliary_workers = self.distributed_information.world_size else: - auxiliary_workers = ( - self.distributed_information.world_size - // self.distributed_information.local_world_size - ) + auxiliary_workers = self.distributed_information.world_size // self.distributed_information.local_world_size # Instantiate a MetricsReporter for all workers which are not rank 0 if ( self.distributed_information.rank is not None and self.distributed_information.rank > 0 - and ( - not workers_only_on_rank_0 - or self.distributed_information.local_rank == 0 - ) + and (not workers_only_on_rank_0 or self.distributed_information.local_rank == 0) and system_tracking_interval is not None ): if workers_only_on_rank_0: - node_rank = ( - distributed_information.rank - // distributed_information.local_world_size - ) + node_rank = distributed_information.rank // distributed_information.local_world_size else: node_rank = distributed_information.rank @@ -402,9 +370,7 @@ def __init__( interval=system_tracking_interval, ) - self.log.info( - f"Distributed worker {self.distributed_information.rank} connected" - ) + self.log.info(f'Distributed worker {self.distributed_information.rank} connected') elif self.distributed_information.rank == 0: # When running as the main worker, we initialize aim as usual. If there're multiple # auxiliary workers, we also start a listening server. 
The auxiliary workers will connect @@ -417,16 +383,12 @@ def __init__( capture_terminal_logs, ) - if ( - auxiliary_workers > 1 - and main_port is not None - and system_tracking_interval is not None - ): - self.log.info(f"There are {auxiliary_workers} workers") + if auxiliary_workers > 1 and main_port is not None and system_tracking_interval is not None: + self.log.info(f'There are {auxiliary_workers} workers') self.metrics_receiver = MetricsReceiver( # Bind to 0.0.0.0 so that we can accept connections coming in from any interface - host="0.0.0.0", + host='0.0.0.0', port=main_port, num_workers=auxiliary_workers - 1, connection_timeout=self.connection_timeout, @@ -449,18 +411,16 @@ def _push_auxiliary_worker_metrics( """ # TODO: Investigate whether it's better to spin up a dedicated RunTracker here or not if self._run is None: - self.log.info( - f"The aim Run is inactive, will not register these metrics from {worker_info}" - ) + self.log.info(f'The aim Run is inactive, will not register these metrics from {worker_info}') return tracker = self._run._tracker context = copy.deepcopy(worker_info) - for resource, usage in stat["system"].items(): + for resource, usage in stat['system'].items(): tracker( usage, - name="{}{}".format( + name='{}{}'.format( aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, resource, ), @@ -468,14 +428,14 @@ def _push_auxiliary_worker_metrics( ) # Store GPU stats - for gpu_idx, gpu in enumerate(stat["gpus"]): + for gpu_idx, gpu in enumerate(stat['gpus']): for resource, usage in gpu.items(): context = copy.deepcopy(worker_info) - context.update({"gpu": gpu_idx}) + context.update({'gpu': gpu_idx}) tracker( usage, - name="{}{}".format( + name='{}{}'.format( aim.ext.resource.configs.AIM_RESOURCE_METRIC_PREFIX, resource, ), diff --git a/aim/sdk/run.py b/aim/sdk/run.py index 59bc4d806..775aed973 100644 --- a/aim/sdk/run.py +++ b/aim/sdk/run.py @@ -567,14 +567,6 @@ def get_metric(self, name: str, context: Context) -> Optional['Metric']: Returns: :obj:`Metric` object if exists, `None` otherwise. """ - if self.read_only and not Run._metric_version_warning_shown: - if self.check_metrics_version(): - logger.warning( - f'Detected sub-optimal format metrics for Run {self.hash}. 
Consider upgrading repo ' - f'to improve queries performance:' - ) - logger.warning(f"aim storage --repo {self.repo.path} upgrade 3.11+ '*'") - Run._metric_version_warning_shown = True return self._get_sequence('metric', name, context) diff --git a/aim/storage/encoding/encoding.pyx b/aim/storage/encoding/encoding.pyx index 71f2ca40b..308627e01 100644 --- a/aim/storage/encoding/encoding.pyx +++ b/aim/storage/encoding/encoding.pyx @@ -22,7 +22,7 @@ from aim.storage.encoding.encoding_native cimport ( decode_double, decode_utf_8_str, ) -from aim.storage.encoding.encoding_native cimport decode_path # noqa F401 +from aim.storage.encoding.encoding_native cimport decode_path # noqa: F401 from aim.storage.utils import ArrayFlagType, ObjectFlagType, CustomObjectFlagType from aim.storage.utils import ArrayFlag, ObjectFlag from aim.storage.container import ContainerValue diff --git a/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx b/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx index 1cd4bb89d..24b28e31f 100644 --- a/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx +++ b/aim/web/ui/src/pages/RunDetail/RunOverviewTab/RunOverviewTab.tsx @@ -8,8 +8,6 @@ import { ANALYTICS_EVENT_KEYS } from 'config/analytics/analyticsKeysMap'; import * as analytics from 'services/analytics'; -import useRunMetricsBatch from '../hooks/useRunMetricsBatch'; - import GitInfoCard from './components/GitInfoCard'; import RunOverviewTabMetricsCard from './components/MetricsCard/RunOverviewTabMetricsCard'; import RunOverviewTabArtifactsCard from './components/ArtifactsCard/RunOverviewTabArtifactsCard'; @@ -28,11 +26,6 @@ function RunOverviewTab({ runData, runHash }: IRunOverviewTabProps) { const overviewSectionContentRef = React.useRef(null); const [containerHeight, setContainerHeight] = React.useState(0); - useRunMetricsBatch({ - runTraces: runData.runTraces, - runHash, - }); - React.useEffect(() => { analytics.pageView( ANALYTICS_EVENT_KEYS.runDetails.tabs['overview'].tabView, From 51b8435c42bef21c37f381d65f1c3014785dbaf0 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Thu, 20 Mar 2025 18:38:40 +0400 Subject: [PATCH 19/30] [fix] Improve RT exception handling (#3309) --- aim/ext/transport/message_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aim/ext/transport/message_utils.py b/aim/ext/transport/message_utils.py index 20f9c717f..ceb52fac2 100644 --- a/aim/ext/transport/message_utils.py +++ b/aim/ext/transport/message_utils.py @@ -52,12 +52,18 @@ def unpack_stream(stream) -> Tuple[bytes, bytes]: def raise_exception(server_exception): + from filelock import Timeout + module = importlib.import_module(server_exception.get('module_name')) exception = getattr(module, server_exception.get('class_name')) args = json.loads(server_exception.get('args') or []) message = server_exception.get('message') - raise exception(*args) if args else Exception(message) + # special handling for lock timeouts as they require lock argument which can't be passed over the network + if exception == Timeout: + raise Exception(message) + + raise exception(*args) if args else exception() def build_exception(exception: Exception): From fba908f153484fc64a807379bdc3ec1238e8ea98 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Thu, 20 Mar 2025 19:39:53 +0400 Subject: [PATCH 20/30] [feat] Move indexing thread to `aim up` main process (#3311) --- aim/cli/up/commands.py | 4 ++++ aim/web/api/__init__.py | 6 ------ 
aim/web/api/projects/views.py | 11 ----------- 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index aafd73b2b..42a23f50c 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -11,6 +11,7 @@ get_repo_instance, set_log_level, ) +from aim.sdk.index_manager import RepoIndexManager from aim.sdk.repo import Repo from aim.sdk.utils import clean_repo_path from aim.web.configs import ( @@ -122,6 +123,9 @@ def up( if profiler: os.environ[AIM_PROFILER_KEY] = '1' + index_mng = RepoIndexManager.get_index_manager(repo_inst) + index_mng.start_indexing_thread() + try: server_cmd = build_uvicorn_command( 'aim.web.run:app', diff --git a/aim/web/api/__init__.py b/aim/web/api/__init__.py index 52b5095b0..cb553a1e4 100644 --- a/aim/web/api/__init__.py +++ b/aim/web/api/__init__.py @@ -23,7 +23,6 @@ def create_app(): max_age=86400, ) - from aim.sdk.index_manager import RepoIndexManager from aim.web.api.dashboard_apps.views import dashboard_apps_router from aim.web.api.dashboards.views import dashboards_router from aim.web.api.experiments.views import experiment_router @@ -36,11 +35,6 @@ def create_app(): from aim.web.api.views import statics_router from aim.web.configs import AIM_UI_BASE_PATH - # The indexing thread has to run in the same process as the uvicorn app itself. - # This allows sharing state of indexing using memory instead of process synchronization methods. - index_mng = RepoIndexManager.get_index_manager(Project().repo) - index_mng.start_indexing_thread() - api_app = FastAPI() api_app.add_middleware(GZipMiddleware, compresslevel=1) api_app.add_middleware(ResourceCleanupMiddleware) diff --git a/aim/web/api/projects/views.py b/aim/web/api/projects/views.py index c32cc5452..36856ba27 100644 --- a/aim/web/api/projects/views.py +++ b/aim/web/api/projects/views.py @@ -5,7 +5,6 @@ from logging import getLogger from typing import Optional, Tuple -from aim.sdk.index_manager import RepoIndexManager from aim.storage.locking import AutoFileLock from aim.web.api.projects.project import Project from aim.web.api.projects.pydantic_models import ( @@ -171,13 +170,3 @@ async def project_params_api(sequence: Optional[Tuple[str, ...]] = Query(()), ex } response.update(**project.repo.collect_sequence_info(sequence)) return response - - -@projects_router.get('/status/') -async def project_status_api(): - project = Project() - - if not project.exists(): - raise HTTPException(status_code=404) - - return RepoIndexManager.get_index_manager(project.repo).repo_status From e02b98bac3d287907e166b5a9b65a50eecc73e86 Mon Sep 17 00:00:00 2001 From: Albert Torosyan Date: Fri, 21 Mar 2025 12:51:19 +0400 Subject: [PATCH 21/30] Bump up Aim to v3.28.0 --- CHANGELOG.md | 16 +++++++++++----- aim/VERSION | 2 +- aim/web/ui/package.json | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0436f92e4..94d102dd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # Changelog -## Unreleased +## 3.28.0 Mar 21, 2025 + +### Enhancements: +- Skip metrics check when run is known to yield false result (alberttorosyan) +- Remove metric version check to improve performance of metric retrieval (mihran113) +- Move indexing thread to main process of `aim up` (alberttorosyan) +- Add AimCallback implementation for hugging face distributed runs (VassilisVassiliadis) +- Add py.typed marker to allow usage of existing type annotations (bluenote10) + ### Fixes: - Decrease client resources keep-alive time (mihran113) @@ -8,11 
+16,9 @@ - Resolve issue with adding duplicate tags to the same run (mihran113) - Improve error messages for remote tracking server (mihran113) - Fix spurious assertion error in message stream parsing (qzed) - -### Enhancements: -- Skip metrics check when run is known to yield false result (alberttorosyan) - Correct indentation on query proxy object return statement (alberttorosyan) -- Remove metric version check to improve performance of metric retrieval (mihran113) +- Fix typing issues in S3ArtifactStorage implementation (sbatchelder) + ## 3.27.0 Dec 18, 2024 diff --git a/aim/VERSION b/aim/VERSION index 8c5312044..a72fd67b6 100644 --- a/aim/VERSION +++ b/aim/VERSION @@ -1 +1 @@ -3.27.0 +3.28.0 diff --git a/aim/web/ui/package.json b/aim/web/ui/package.json index 567b5732a..c9a9976a0 100644 --- a/aim/web/ui/package.json +++ b/aim/web/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui_v2", - "version": "3.27.0", + "version": "3.28.0", "private": true, "dependencies": { "@aksel/structjs": "^1.0.0", From 897459a48bd31021af50ee0b3b2c172077f41d09 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Tue, 1 Apr 2025 11:33:44 +0400 Subject: [PATCH 22/30] [feat] Constant indexing of in-progress Runs (#3310) --- aim/cli/runs/commands.py | 2 +- aim/cli/storage/commands.py | 4 +- aim/cli/up/commands.py | 4 +- aim/sdk/index_manager.py | 272 ++++++++++++++------------------- aim/sdk/repo.py | 26 ++-- aim/storage/rockscontainer.pyx | 2 +- setup.py | 1 + 7 files changed, 132 insertions(+), 179 deletions(-) diff --git a/aim/cli/runs/commands.py b/aim/cli/runs/commands.py index 9f7413db5..26a4b58d6 100644 --- a/aim/cli/runs/commands.py +++ b/aim/cli/runs/commands.py @@ -192,7 +192,7 @@ def update_metrics(ctx, yes): if not confirmed: return - index_manager = RepoIndexManager.get_index_manager(repo) + index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) hashes = repo.list_all_runs() for run_hash in tqdm.tqdm(hashes, desc='Updating runs', total=len(hashes)): meta_tree = repo.request_tree('meta', run_hash, read_only=False, from_union=False) diff --git a/aim/cli/storage/commands.py b/aim/cli/storage/commands.py index 3210f7a69..32bfe01d5 100644 --- a/aim/cli/storage/commands.py +++ b/aim/cli/storage/commands.py @@ -51,7 +51,7 @@ def to_3_11(ctx, hashes, yes): if not confirmed: return - index_manager = RepoIndexManager.get_index_manager(repo) + index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) for run_hash in tqdm(matched_hashes): try: run = Run(run_hash, repo=repo) @@ -97,7 +97,7 @@ def restore_runs(ctx, hashes, yes): return remaining_runs = [] - index_manager = RepoIndexManager.get_index_manager(repo) + index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) for run_hash in tqdm(matched_hashes): try: restore_run_backup(repo, run_hash) diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index 42a23f50c..e294c4f9b 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -123,9 +123,7 @@ def up( if profiler: os.environ[AIM_PROFILER_KEY] = '1' - index_mng = RepoIndexManager.get_index_manager(repo_inst) - index_mng.start_indexing_thread() - + RepoIndexManager.get_index_manager(repo_inst) try: server_cmd = build_uvicorn_command( 'aim.web.run:app', diff --git a/aim/sdk/index_manager.py b/aim/sdk/index_manager.py index 7c26cb2bf..f7e01502b 100644 --- a/aim/sdk/index_manager.py +++ b/aim/sdk/index_manager.py @@ -1,191 +1,143 @@ -import contextlib -import datetime import 
logging import os -import time +import queue +import threading from pathlib import Path -from threading import Thread -from typing import Iterable import aimrocks.errors -import pytz from aim.sdk.repo import Repo -from aim.sdk.run_status_watcher import Event -from aim.storage.locking import RefreshLock +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer logger = logging.getLogger(__name__) +class NewChunkCreatedHandler(FileSystemEventHandler): + def __init__(self, manager): + self.manager = manager + + def on_created(self, event): + if event.is_directory: + chunk_name = os.path.basename(event.src_path) + logger.debug(f'Detected new chunk directory: {chunk_name}') + self.manager.monitor_chunk_directory(event.src_path) + + +class ChunkChangedHandler(FileSystemEventHandler): + def __init__(self, manager): + self.manager = manager + self.pending_events = set() + self.lock = threading.Lock() + + def _trigger_event(self, run_hash): + with self.lock: + if run_hash not in self.pending_events: + self.pending_events.add(run_hash) + threading.Timer(0.5, self._process_event, [run_hash]).start() + + def _process_event(self, run_hash): + with self.lock: + if run_hash in self.pending_events: + self.pending_events.remove(run_hash) + logger.debug(f'Triggering indexing for run {run_hash}') + self.manager.add_run_to_queue(run_hash) + + def on_any_event(self, event): + if event.is_directory: + return + + event_path = Path(event.src_path) + parent_dir = event_path.parent + run_hash = parent_dir.name + + # Ensure the parent directory is directly inside meta/chunks/ + if parent_dir.parent != self.manager.chunks_dir: + logger.debug(f'Skipping event outside valid chunk directory: {event.src_path}') + return + + if event_path.name.startswith('LOG'): + logger.debug(f'Skipping event for LOG-prefixed file: {event.src_path}') + return + + logger.debug(f'Detected change in {event.src_path}') + self._trigger_event(run_hash) + + class RepoIndexManager: index_manager_pool = {} - INDEXING_GRACE_PERIOD = 10 @classmethod - def get_index_manager(cls, repo: Repo): + def get_index_manager(cls, repo: Repo, disable_monitoring: bool = False): mng = cls.index_manager_pool.get(repo.path, None) if mng is None: - mng = RepoIndexManager(repo) + mng = RepoIndexManager(repo, disable_monitoring) cls.index_manager_pool[repo.path] = mng return mng - def __init__(self, repo: Repo): + def __init__(self, repo: Repo, disable_monitoring: bool): self.repo_path = repo.path self.repo = repo - self.progress_dir = Path(self.repo_path) / 'meta' / 'progress' - self.progress_dir.mkdir(parents=True, exist_ok=True) - - self.heartbeat_dir = Path(self.repo_path) / 'check_ins' - self.run_heartbeat_cache = {} + self.chunks_dir = Path(self.repo_path) / 'meta' / 'chunks' + self.chunks_dir.mkdir(parents=True, exist_ok=True) - self._indexing_in_progress = False - self._reindex_thread: Thread = None self._corrupted_runs = set() - @property - def repo_status(self): - if self._indexing_in_progress is True: - return 'indexing in progress' - if self.reindex_needed: - return 'needs indexing' - return 'up-to-date' - - @property - def reindex_needed(self) -> bool: - runs_with_progress = os.listdir(self.progress_dir) - return len(runs_with_progress) > 0 - - def start_indexing_thread(self): - logger.info(f"Starting indexing thread for repo '{self.repo_path}'") - self._reindex_thread = Thread(target=self._run_forever, daemon=True) - self._reindex_thread.start() - - def _run_forever(self): - idle_cycles = 0 - while True: - 
self._indexing_in_progress = False - for run_hash in self._next_stalled_run(): - logger.info(f'Found un-indexed run {run_hash}. Indexing...') - self._indexing_in_progress = True - idle_cycles = 0 - self.index(run_hash) - - # sleep for small interval to release index db lock in between and allow - # other running jobs to properly finalize and index Run. - sleep_interval = 0.1 - time.sleep(sleep_interval) - if not self._indexing_in_progress: - idle_cycles += 1 - sleep_interval = 2 * idle_cycles if idle_cycles < 5 else 10 - logger.info( - f'No un-indexed runs found. Next check will run in {sleep_interval} seconds. ' - f'Waiting for un-indexed run...' - ) - time.sleep(sleep_interval) - - def _runs_with_progress(self) -> Iterable[str]: - runs_with_progress = filter(lambda x: x not in self._corrupted_runs, os.listdir(self.progress_dir)) - run_hashes = sorted(runs_with_progress, key=lambda r: os.path.getmtime(os.path.join(self.progress_dir, r))) - return run_hashes - - def _next_stalled_run(self): - for run_hash in self._runs_with_progress(): - if self._is_run_stalled(run_hash): - yield run_hash - - def _is_run_stalled(self, run_hash: str) -> bool: - stalled = False - heartbeat_files = list(sorted(self.heartbeat_dir.glob(f'{run_hash}-*-progress-*-*'), reverse=True)) - if heartbeat_files: - last_heartbeat = Event(heartbeat_files[0].name) - last_recorded_heartbeat = self.run_heartbeat_cache.get(run_hash) - if last_recorded_heartbeat is None: - self.run_heartbeat_cache[run_hash] = last_heartbeat - elif last_heartbeat.idx > last_recorded_heartbeat.idx: - self.run_heartbeat_cache[run_hash] = last_heartbeat - else: - time_passed = time.time() - last_recorded_heartbeat.detected_epoch_time - if last_recorded_heartbeat.next_event_in + RepoIndexManager.INDEXING_GRACE_PERIOD < time_passed: - stalled = True + if not disable_monitoring: + self.indexing_queue = queue.PriorityQueue() + self.lock = threading.Lock() + + self.observer = Observer() + self.new_chunk_handler = NewChunkCreatedHandler(self) + self.chunk_change_handler = ChunkChangedHandler(self) + + self.observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=True) + self._monitor_existing_chunks() + self.observer.start() + + self._reindex_thread = threading.Thread(target=self._process_queue, daemon=True) + self._reindex_thread.start() + + def _monitor_existing_chunks(self): + for chunk_path in self.chunks_dir.iterdir(): + if chunk_path.is_dir(): + logger.debug(f'Monitoring existing chunk: {chunk_path}') + self.monitor_chunk_directory(chunk_path) + + def monitor_chunk_directory(self, chunk_path): + """Ensure chunk directory is monitored using a single handler.""" + if str(chunk_path) not in self.observer._watches: + self.observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) + logger.debug(f'Started monitoring chunk directory: {chunk_path}') else: - stalled = True - return stalled + logger.debug(f'Chunk directory already monitored: {chunk_path}') - def _index_lock_path(self): - return Path(self.repo.path) / 'locks' / 'index' + def add_run_to_queue(self, run_hash): + if run_hash in self._corrupted_runs: + return + timestamp = os.path.getmtime(os.path.join(self.chunks_dir, run_hash)) + with self.lock: + self.indexing_queue.put((timestamp, run_hash)) + logger.debug(f'Run {run_hash} added to indexing queue with timestamp {timestamp}') - @contextlib.contextmanager - def lock_index(self, lock: RefreshLock): - try: - self._safe_acquire_lock(lock) - yield - finally: - lock.release() - - def _safe_acquire_lock(self, lock: 
RefreshLock): - last_touch_seen = None - prev_touch_time = None - last_owner_id = None + def _process_queue(self): while True: - try: - lock.acquire() - logger.debug('Lock is acquired!') - break - except TimeoutError: - owner_id = lock.owner_id() - if owner_id != last_owner_id: - logger.debug(f'Lock has been acquired by {owner_id}') - last_owner_id = owner_id - prev_touch_time = None - else: # same holder as from prev. iteration - last_touch_time = lock.last_refresh_time() - if last_touch_time != prev_touch_time: - prev_touch_time = last_touch_time - last_touch_seen = time.time() - logger.debug(f'Lock has been refreshed. Touch time: {last_touch_time}') - continue - assert last_touch_seen is not None - if time.time() - last_touch_seen > RefreshLock.GRACE_PERIOD: - logger.debug('Grace period exceeded. Force-acquiring the lock.') - with lock.meta_lock(): - # double check holder ID - if lock.owner_id() != last_owner_id: # someone else grabbed lock - continue - else: - lock.force_release() - try: - lock.acquire() - logger.debug('lock has been forcefully acquired!') - break - except TimeoutError: - continue - else: - logger.debug( - f'Countdown to force-acquire lock. ' - f'Time remaining: {RefreshLock.GRACE_PERIOD - (time.time() - last_touch_seen)}' - ) - - def run_needs_indexing(self, run_hash: str) -> bool: - return os.path.exists(self.progress_dir / run_hash) - - def index( - self, - run_hash, - ) -> bool: - lock = RefreshLock(self._index_lock_path(), timeout=10) - with self.lock_index(lock): - index = self.repo._get_index_tree('meta', 0).view(()) - try: - meta_tree = self.repo.request_tree( - 'meta', run_hash, read_only=True, from_union=False, no_cache=True - ).subtree('meta') - meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) - meta_run_tree.finalize(index=index) - if meta_run_tree.get('end_time') is None: - index['meta', 'chunks', run_hash, 'end_time'] = datetime.datetime.now(pytz.utc).timestamp() - except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): - logger.warning(f"Indexing thread detected corrupted run '{run_hash}'. Skipping.") - self._corrupted_runs.add(run_hash) - return True + _, run_hash = self.indexing_queue.get() + logger.debug(f'Indexing run {run_hash}...') + self.index(run_hash) + self.indexing_queue.task_done() + + def index(self, run_hash): + index = self.repo._get_index_tree('meta', 0).view(()) + try: + meta_tree = self.repo.request_tree( + 'meta', run_hash, read_only=True, from_union=False, no_cache=True, skip_read_optimization=True + ).subtree('meta') + meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) + meta_run_tree.finalize(index=index) + except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): + logger.warning(f"Indexing thread detected corrupted run '{run_hash}'. 
Skipping.") + self._corrupted_runs.add(run_hash) + return True diff --git a/aim/sdk/repo.py b/aim/sdk/repo.py index 6d2471cb2..b37838421 100644 --- a/aim/sdk/repo.py +++ b/aim/sdk/repo.py @@ -269,19 +269,22 @@ def get_version(cls, path: str): def is_remote_path(cls, path: str): return path.startswith('aim://') - def _get_container(self, name: str, read_only: bool, from_union: bool = False) -> Container: + def _get_container(self, name: str, read_only: bool, from_union: bool = False, skip_read_optimization: bool = False) -> Container: + # TODO [AT]: refactor get container/tree logic to make it more simple if self.read_only and not read_only: raise ValueError('Repo is read-only') container_config = ContainerConfig(name, None, read_only=read_only) container = self.container_pool.get(container_config) if container is None: - path = os.path.join(self.path, name) if from_union: - container = RocksUnionContainer(path, read_only=read_only) + # Temporarily use index db when getting data from union. + path = os.path.join(self.path, name, 'index') + container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) self.persistent_pool[container_config] = container else: - container = RocksContainer(path, read_only=read_only) + path = os.path.join(self.path, name) + container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) self.container_pool[container_config] = container return container @@ -314,9 +317,11 @@ def request_tree( read_only: bool, from_union: bool = False, # TODO maybe = True by default no_cache: bool = False, + skip_read_optimization: bool = False ): if not self.is_remote_repo: - return self.request(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache).tree() + return self.request(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache, + skip_read_optimization=skip_read_optimization).tree() else: return ProxyTree(self._client, name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache) @@ -328,6 +333,7 @@ def request( read_only: bool, from_union: bool = False, # TODO maybe = True by default no_cache: bool = False, + skip_read_optimization: bool = False ): container_config = ContainerConfig(name, sub, read_only) container_view = self.container_view_pool.get(container_config) @@ -338,7 +344,8 @@ def request( else: assert sub is not None path = os.path.join(name, 'chunks', sub) - container = self._get_container(path, read_only=True, from_union=from_union) + container = self._get_container(path, read_only=True, from_union=from_union, + skip_read_optimization=skip_read_optimization) else: assert sub is not None path = os.path.join(name, 'chunks', sub) @@ -1005,10 +1012,7 @@ def optimize_container(path, extra_options): if self.is_remote_repo: self._remote_repo_proxy._close_run(run_hash) - from aim.sdk.index_manager import RepoIndexManager - lock_manager = LockManager(self.path) - index_manager = RepoIndexManager.get_index_manager(self) if lock_manager.release_locks(run_hash, force=True): # Run rocksdb optimizations if container locks are removed @@ -1016,8 +1020,6 @@ def optimize_container(path, extra_options): seqs_db_path = os.path.join(self.path, 'seqs', 'chunks', run_hash) optimize_container(meta_db_path, extra_options={'compaction': True}) optimize_container(seqs_db_path, extra_options={}) - if index_manager.run_needs_indexing(run_hash): - index_manager.index(run_hash) def _recreate_index(self): from tqdm import tqdm @@ -1028,7 +1030,7 @@ def 
_recreate_index(self): from aim.sdk.index_manager import RepoIndexManager - index_manager = RepoIndexManager.get_index_manager(self) + index_manager = RepoIndexManager.get_index_manager(self, disable_monitoring=True) # force delete the index db and the locks diff --git a/aim/storage/rockscontainer.pyx b/aim/storage/rockscontainer.pyx index 1be6f9086..d21e28c32 100644 --- a/aim/storage/rockscontainer.pyx +++ b/aim/storage/rockscontainer.pyx @@ -144,7 +144,7 @@ class RocksContainer(Container): lock_cls = self.get_lock_cls() self._lock = lock_cls(self._lock_path, timeout) self._lock.acquire() - else: + elif not self._extra_opts.get('skip_read_optimization', False): self.optimize_for_read() self._db = aimrocks.DB(str(self.path), diff --git a/setup.py b/setup.py index 00725b280..983b38016 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ def package_files(directory): 'packaging>=15.0', 'python-dateutil', 'requests', + 'watchdog', 'websockets', 'boto3', ] From e206b50bfcfb21e9368b968f70f4391b44139231 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Wed, 2 Apr 2025 15:46:34 +0400 Subject: [PATCH 23/30] [fix] Resolve issue of min/max calculation for single point metrics (#3315) --- CHANGELOG.md | 5 +++++ aim/web/ui/src/utils/aggregateGroupData.ts | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94d102dd8..d55c7567e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased: + +### Fixes: +- Fix min/max calculation for single point metrics (mihran113) + ## 3.28.0 Mar 21, 2025 ### Enhancements: diff --git a/aim/web/ui/src/utils/aggregateGroupData.ts b/aim/web/ui/src/utils/aggregateGroupData.ts index 0a0acc69e..2fa05644a 100644 --- a/aim/web/ui/src/utils/aggregateGroupData.ts +++ b/aim/web/ui/src/utils/aggregateGroupData.ts @@ -113,6 +113,17 @@ export function aggregateGroupData({ } } } + // add special case handling for single point metrics + if (trace.xValues.length === 1) { + const step = trace.xValues[0]; + let value = chartXValues.indexOf(step); + let y = trace.yValues[0]; + if (yValuesPerX.hasOwnProperty(value)) { + yValuesPerX[value].push(y); + } else { + yValuesPerX[value] = [y]; + } + } } } From 943942c54e613697f29babe56dde0dc1d18f69e7 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:09:34 +0400 Subject: [PATCH 24/30] [fix] Use polling observer to make sure new file modifications are detected (#3316) --- aim/sdk/index_manager.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/aim/sdk/index_manager.py b/aim/sdk/index_manager.py index f7e01502b..ff8eaa723 100644 --- a/aim/sdk/index_manager.py +++ b/aim/sdk/index_manager.py @@ -10,6 +10,7 @@ from aim.sdk.repo import Repo from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer +from watchdog.observers.polling import PollingObserver logger = logging.getLogger(__name__) @@ -20,7 +21,7 @@ def __init__(self, manager): self.manager = manager def on_created(self, event): - if event.is_directory: + if event.is_directory and Path(event.src_path).parent == self.manager.chunks_dir: chunk_name = os.path.basename(event.src_path) logger.debug(f'Detected new chunk directory: {chunk_name}') self.manager.monitor_chunk_directory(event.src_path) @@ -89,13 +90,17 @@ def __init__(self, repo: Repo, disable_monitoring: bool): self.indexing_queue = queue.PriorityQueue() self.lock = threading.Lock() - self.observer = Observer() + 
self.new_chunk_observer = Observer() + self.chunk_change_observer = PollingObserver() + self.new_chunk_handler = NewChunkCreatedHandler(self) self.chunk_change_handler = ChunkChangedHandler(self) - self.observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=True) + self.new_chunk_observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=True) + self.new_chunk_observer.start() + self._monitor_existing_chunks() - self.observer.start() + self.chunk_change_observer.start() self._reindex_thread = threading.Thread(target=self._process_queue, daemon=True) self._reindex_thread.start() @@ -108,8 +113,8 @@ def _monitor_existing_chunks(self): def monitor_chunk_directory(self, chunk_path): """Ensure chunk directory is monitored using a single handler.""" - if str(chunk_path) not in self.observer._watches: - self.observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) + if str(chunk_path) not in self.chunk_change_observer._watches: + self.chunk_change_observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) logger.debug(f'Started monitoring chunk directory: {chunk_path}') else: logger.debug(f'Chunk directory already monitored: {chunk_path}') From 02bdcdd21203e36f6f315e5050f42db836d639f5 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:21:15 +0400 Subject: [PATCH 25/30] [feat] Mark stalled runs as finished (#3314) --- aim/cli/up/commands.py | 3 + aim/sdk/repo.py | 23 ++- aim/sdk/reporter/file_manager.py | 6 +- aim/sdk/run_status_manager.py | 95 ++++++++++++ aim/sdk/run_status_watcher.py | 9 +- aim/storage/arrayview.py | 12 +- aim/storage/artifacts/artifact_storage.py | 9 +- aim/storage/inmemorytreeview.py | 6 +- aim/storage/query.py | 3 +- aim/storage/rockscontainer.pyx | 12 +- aim/storage/structured/entities.py | 171 ++++++++++++++-------- aim/storage/treeview.py | 39 +++-- 12 files changed, 289 insertions(+), 99 deletions(-) create mode 100644 aim/sdk/run_status_manager.py diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index e294c4f9b..4775e0aa8 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -13,6 +13,7 @@ ) from aim.sdk.index_manager import RepoIndexManager from aim.sdk.repo import Repo +from aim.sdk.run_status_manager import RunStatusManager from aim.sdk.utils import clean_repo_path from aim.web.configs import ( AIM_ENV_MODE_KEY, @@ -124,6 +125,8 @@ def up( os.environ[AIM_PROFILER_KEY] = '1' RepoIndexManager.get_index_manager(repo_inst) + run_status_mng = RunStatusManager(repo_inst) + run_status_mng.start() try: server_cmd = build_uvicorn_command( 'aim.web.run:app', diff --git a/aim/sdk/repo.py b/aim/sdk/repo.py index b37838421..151a56f86 100644 --- a/aim/sdk/repo.py +++ b/aim/sdk/repo.py @@ -269,7 +269,9 @@ def get_version(cls, path: str): def is_remote_path(cls, path: str): return path.startswith('aim://') - def _get_container(self, name: str, read_only: bool, from_union: bool = False, skip_read_optimization: bool = False) -> Container: + def _get_container( + self, name: str, read_only: bool, from_union: bool = False, skip_read_optimization: bool = False + ) -> Container: # TODO [AT]: refactor get container/tree logic to make it more simple if self.read_only and not read_only: raise ValueError('Repo is read-only') @@ -317,11 +319,17 @@ def request_tree( read_only: bool, from_union: bool = False, # TODO maybe = True by default no_cache: bool = False, - skip_read_optimization: bool = False + skip_read_optimization: bool = False, ): if 
not self.is_remote_repo: - return self.request(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache, - skip_read_optimization=skip_read_optimization).tree() + return self.request( + name, + sub, + read_only=read_only, + from_union=from_union, + no_cache=no_cache, + skip_read_optimization=skip_read_optimization, + ).tree() else: return ProxyTree(self._client, name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache) @@ -333,7 +341,7 @@ def request( read_only: bool, from_union: bool = False, # TODO maybe = True by default no_cache: bool = False, - skip_read_optimization: bool = False + skip_read_optimization: bool = False, ): container_config = ContainerConfig(name, sub, read_only) container_view = self.container_view_pool.get(container_config) @@ -344,8 +352,9 @@ def request( else: assert sub is not None path = os.path.join(name, 'chunks', sub) - container = self._get_container(path, read_only=True, from_union=from_union, - skip_read_optimization=skip_read_optimization) + container = self._get_container( + path, read_only=True, from_union=from_union, skip_read_optimization=skip_read_optimization + ) else: assert sub is not None path = os.path.join(name, 'chunks', sub) diff --git a/aim/sdk/reporter/file_manager.py b/aim/sdk/reporter/file_manager.py index 80c2d9a85..72633f084 100644 --- a/aim/sdk/reporter/file_manager.py +++ b/aim/sdk/reporter/file_manager.py @@ -10,10 +10,12 @@ class FileManager(object): @abstractmethod - def poll(self, pattern: str) -> Optional[str]: ... + def poll(self, pattern: str) -> Optional[str]: + ... @abstractmethod - def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ... + def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): + ... class LocalFileManager(FileManager): diff --git a/aim/sdk/run_status_manager.py b/aim/sdk/run_status_manager.py new file mode 100644 index 000000000..71dc42eeb --- /dev/null +++ b/aim/sdk/run_status_manager.py @@ -0,0 +1,95 @@ +import time +import os +import datetime +import pytz +import threading +from pathlib import Path + +from typing import Iterable + +import aimrocks.errors + +from aim import Repo +from aim.sdk.run_status_watcher import Event + + +class RunStatusManager: + INDEXING_GRACE_PERIOD = 10 + + def __init__(self, repo: Repo, scan_interval: int = 60): + self.repo = repo + self.scan_interval = scan_interval + + self.progress_dir = Path(self.repo.path) / 'meta' / 'progress' + self.progress_dir.mkdir(parents=True, exist_ok=True) + + self.heartbeat_dir = Path(self.repo.path) / 'check_ins' + self.run_heartbeat_cache = {} + + self._stop_event = threading.Event() + self._monitor_thread = None + self._corrupted_runs = set() + + def start(self): + if not self._monitor_thread or not self._monitor_thread.is_alive(): + self._stop_event.clear() + self._monitor_thread = threading.Thread(target=self._run_forever, daemon=True) + self._monitor_thread.start() + + def stop(self): + self._stop_event.set() + if self._monitor_thread: + self._monitor_thread.join() + + def _run_forever(self): + while not self._stop_event.is_set(): + self.check_and_terminate_stalled_runs() + time.sleep(self.scan_interval) + + def _runs_with_progress(self) -> Iterable[str]: + runs_with_progress = filter(lambda x: x not in self._corrupted_runs, os.listdir(self.progress_dir)) + run_hashes = sorted(runs_with_progress, key=lambda r: os.path.getmtime(os.path.join(self.progress_dir, r))) + return run_hashes + + def check_and_terminate_stalled_runs(self): + for run_hash in 
self._runs_with_progress(): + if self._is_run_stalled(run_hash): + self._mark_run_as_terminated(run_hash) + + def _is_run_stalled(self, run_hash: str) -> bool: + stalled = False + + heartbeat_files = list(sorted(self.heartbeat_dir.glob(f'{run_hash}-*-progress-*-*'), reverse=True)) + if heartbeat_files: + latest_file = heartbeat_files[0].name + last_heartbeat = Event(latest_file) + + last_recorded_heartbeat = self.run_heartbeat_cache.get(run_hash) + if last_recorded_heartbeat is None: + # First time seeing a heartbeat for this run; store and move on + self.run_heartbeat_cache[run_hash] = last_heartbeat + elif last_heartbeat.idx > last_recorded_heartbeat.idx: + # Newer heartbeat arrived, so the run isn't stalled + self.run_heartbeat_cache[run_hash] = last_heartbeat + else: + # No new heartbeat event since last time; check if enough time passed + time_passed = time.time() - last_recorded_heartbeat.detected_epoch_time + if (last_recorded_heartbeat.next_event_in + RunStatusManager.INDEXING_GRACE_PERIOD) < time_passed: + stalled = True + else: + stalled = True + + return stalled + + def _mark_run_as_terminated(self, run_hash: str): + # TODO [AT]: Add run state handling once decided on terms (finished, terminated, aborted, etc.) + try: + meta_run_tree = self.repo.request_tree('meta', run_hash, read_only=False).subtree( + ('meta', 'chunks', run_hash) + ) + if meta_run_tree.get('end_time') is None: + meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() + progress_path = self.progress_dir / run_hash + progress_path.unlink(missing_ok=True) + except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): + self._corrupted_runs.add(run_hash) diff --git a/aim/sdk/run_status_watcher.py b/aim/sdk/run_status_watcher.py index 422cbff12..ccf203bd5 100644 --- a/aim/sdk/run_status_watcher.py +++ b/aim/sdk/run_status_watcher.py @@ -83,13 +83,16 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None, self.message = message @abstractmethod - def is_sent(self): ... + def is_sent(self): + ... @abstractmethod - def update_last_sent(self): ... + def update_last_sent(self): + ... @abstractmethod - def get_msg_details(self): ... + def get_msg_details(self): + ... class StatusNotification(Notification): diff --git a/aim/storage/arrayview.py b/aim/storage/arrayview.py index 4694c1eab..2b9fd8954 100644 --- a/aim/storage/arrayview.py +++ b/aim/storage/arrayview.py @@ -9,7 +9,8 @@ class ArrayView: when index values are not important. """ - def __iter__(self) -> Iterator[Any]: ... + def __iter__(self) -> Iterator[Any]: + ... def keys(self) -> Iterator[int]: """Return sparse indices iterator. @@ -43,13 +44,16 @@ def items(self) -> Iterator[Tuple[int, Any]]: """ ... - def __len__(self) -> int: ... + def __len__(self) -> int: + ... - def __getitem__(self, idx: Union[int, slice]): ... + def __getitem__(self, idx: Union[int, slice]): + ... # TODO implement append - def __setitem__(self, idx: int, val: Any): ... + def __setitem__(self, idx: int, val: Any): + ... def sparse_list(self) -> Tuple[List[int], List[Any]]: """Get sparse indices and values as :obj:`list`s.""" diff --git a/aim/storage/artifacts/artifact_storage.py b/aim/storage/artifacts/artifact_storage.py index efa73cbd1..e0bab8934 100644 --- a/aim/storage/artifacts/artifact_storage.py +++ b/aim/storage/artifacts/artifact_storage.py @@ -7,10 +7,13 @@ def __init__(self, url: str): self.url = url @abstractmethod - def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): ... 
+ def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): + ... @abstractmethod - def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) -> str: ... + def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) -> str: + ... @abstractmethod - def delete_artifact(self, artifact_path: str): ... + def delete_artifact(self, artifact_path: str): + ... diff --git a/aim/storage/inmemorytreeview.py b/aim/storage/inmemorytreeview.py index 7d02c347d..1ce208594 100644 --- a/aim/storage/inmemorytreeview.py +++ b/aim/storage/inmemorytreeview.py @@ -117,6 +117,8 @@ def iterlevel( def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> TreeArrayView: return TreeArrayView(self.subtree(path), dtype=dtype) - def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... + def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: + ... - def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... + def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: + ... diff --git a/aim/storage/query.py b/aim/storage/query.py index 82de23657..f8fa81fbb 100644 --- a/aim/storage/query.py +++ b/aim/storage/query.py @@ -80,7 +80,8 @@ def __init__(self, expr: str): self.expr = expr @abstractmethod - def check(self, **params) -> bool: ... + def check(self, **params) -> bool: + ... def __call__(self, **params): return self.check(**params) diff --git a/aim/storage/rockscontainer.pyx b/aim/storage/rockscontainer.pyx index d21e28c32..e96fc4b42 100644 --- a/aim/storage/rockscontainer.pyx +++ b/aim/storage/rockscontainer.pyx @@ -35,6 +35,7 @@ class RocksAutoClean(AutoClean): super().__init__(instance) self._lock = None self._db = None + self._progress_path = None def _close(self): """ @@ -48,6 +49,9 @@ class RocksAutoClean(AutoClean): self._db = None self._lock.release() self._lock = None + if self._progress_path is not None: + self._progress_path.unlink(missing_ok=True) + self._progress_path = None if self._db is not None: self._db = None @@ -104,6 +108,7 @@ class RocksContainer(Container): if not self.read_only: progress_dir.mkdir(parents=True, exist_ok=True) self._progress_path.touch(exist_ok=True) + self._resources._progress_path = self._progress_path self.db # TODO check if Containers are reopenable @@ -159,16 +164,9 @@ class RocksContainer(Container): Store the collection of `(key, value)` records in the :obj:`Container` `index` for fast reads. """ - if not self._progress_path: - return - for k, v in self.items(): index[k] = v - if self._progress_path.exists(): - self._progress_path.unlink() - self._progress_path = None - def close(self): """Close all the resources.""" if self._resources is None: diff --git a/aim/storage/structured/entities.py b/aim/storage/structured/entities.py index 900c422ec..a43471ea7 100644 --- a/aim/storage/structured/entities.py +++ b/aim/storage/structured/entities.py @@ -13,224 +13,281 @@ class StructuredObject(ABC): @classmethod @abstractmethod - def fields(cls): ... + def fields(cls): + ... class Searchable(ABC, Generic[T]): @classmethod @abstractmethod - def find(cls, _id: str, **kwargs) -> Optional[T]: ... + def find(cls, _id: str, **kwargs) -> Optional[T]: + ... @classmethod @abstractmethod - def all(cls, **kwargs) -> Collection[T]: ... + def all(cls, **kwargs) -> Collection[T]: + ... @classmethod @abstractmethod - def search(cls, term: str, **kwargs) -> Collection[T]: ... 
+ def search(cls, term: str, **kwargs) -> Collection[T]: + ... class Run(StructuredObject, Searchable['Run']): @property @abstractmethod - def hash(self) -> str: ... + def hash(self) -> str: + ... @property @abstractmethod - def name(self) -> Optional[str]: ... + def name(self) -> Optional[str]: + ... @name.setter @abstractmethod - def name(self, value: str): ... + def name(self, value: str): + ... @property @abstractmethod - def description(self) -> Optional[str]: ... + def description(self) -> Optional[str]: + ... @description.setter @abstractmethod - def description(self, value: str): ... + def description(self, value: str): + ... @property @abstractmethod - def archived(self) -> bool: ... + def archived(self) -> bool: + ... @archived.setter @abstractmethod - def archived(self, value: bool): ... + def archived(self, value: bool): + ... @property @abstractmethod - def experiment(self) -> Optional['Experiment']: ... + def experiment(self) -> Optional['Experiment']: + ... @experiment.setter @abstractmethod - def experiment(self, value: str): ... + def experiment(self, value: str): + ... @property @abstractmethod - def tags(self) -> TagCollection: ... + def tags(self) -> TagCollection: + ... @abstractmethod - def add_tag(self, value: str) -> 'Tag': ... + def add_tag(self, value: str) -> 'Tag': + ... @abstractmethod - def remove_tag(self, tag_name: str) -> bool: ... + def remove_tag(self, tag_name: str) -> bool: + ... @property @abstractmethod - def info(self) -> 'RunInfo': ... + def info(self) -> 'RunInfo': + ... class Experiment(StructuredObject, Searchable['Experiment']): @property @abstractmethod - def uuid(self) -> str: ... + def uuid(self) -> str: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @name.setter @abstractmethod - def name(self, value: str): ... + def name(self, value: str): + ... @property @abstractmethod - def description(self) -> Optional[str]: ... + def description(self) -> Optional[str]: + ... @description.setter @abstractmethod - def description(self, value: str): ... + def description(self, value: str): + ... @property @abstractmethod - def archived(self) -> bool: ... + def archived(self) -> bool: + ... @archived.setter @abstractmethod - def archived(self, value: bool): ... + def archived(self, value: bool): + ... @property @abstractmethod - def runs(self) -> RunCollection: ... + def runs(self) -> RunCollection: + ... class Tag(StructuredObject, Searchable['Tag']): @property @abstractmethod - def uuid(self) -> str: ... + def uuid(self) -> str: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @name.setter @abstractmethod - def name(self, value: str): ... + def name(self, value: str): + ... @property @abstractmethod - def color(self) -> str: ... + def color(self) -> str: + ... @color.setter @abstractmethod - def color(self, value: str): ... + def color(self, value: str): + ... @property @abstractmethod - def description(self) -> str: ... + def description(self) -> str: + ... @description.setter @abstractmethod - def description(self, value: str): ... + def description(self, value: str): + ... @property @abstractmethod - def archived(self) -> bool: ... + def archived(self) -> bool: + ... @archived.setter @abstractmethod - def archived(self, value: bool): ... + def archived(self, value: bool): + ... @property @abstractmethod - def runs(self) -> RunCollection: ... + def runs(self) -> RunCollection: + ... 
class Note(StructuredObject, Searchable['Note']): @property @abstractmethod - def id(self) -> int: ... + def id(self) -> int: + ... @property @abstractmethod - def content(self) -> str: ... + def content(self) -> str: + ... @content.setter @abstractmethod - def content(self, value: str): ... + def content(self, value: str): + ... @property @abstractmethod - def run(self) -> int: ... + def run(self) -> int: + ... class RunInfo(StructuredObject, Generic[T]): @property @abstractmethod - def last_notification_index(self) -> int: ... + def last_notification_index(self) -> int: + ... @last_notification_index.setter @abstractmethod - def last_notification_index(self, value: int): ... + def last_notification_index(self, value: int): + ... class ObjectFactory: @abstractmethod - def runs(self) -> RunCollection: ... + def runs(self) -> RunCollection: + ... @abstractmethod - def search_runs(self, term: str) -> RunCollection: ... + def search_runs(self, term: str) -> RunCollection: + ... @abstractmethod - def find_run(self, _id: str) -> Run: ... + def find_run(self, _id: str) -> Run: + ... @abstractmethod - def find_runs(self, ids: List[str]) -> List[Run]: ... + def find_runs(self, ids: List[str]) -> List[Run]: + ... @abstractmethod - def create_run(self, runhash: str) -> Run: ... + def create_run(self, runhash: str) -> Run: + ... @abstractmethod - def delete_run(self, runhash: str) -> bool: ... + def delete_run(self, runhash: str) -> bool: + ... @abstractmethod - def experiments(self) -> ExperimentCollection: ... + def experiments(self) -> ExperimentCollection: + ... @abstractmethod - def search_experiments(self, term: str) -> ExperimentCollection: ... + def search_experiments(self, term: str) -> ExperimentCollection: + ... @abstractmethod - def find_experiment(self, _id: str) -> Experiment: ... + def find_experiment(self, _id: str) -> Experiment: + ... @abstractmethod - def create_experiment(self, name: str) -> Experiment: ... + def create_experiment(self, name: str) -> Experiment: + ... @abstractmethod - def delete_experiment(self, _id: str) -> bool: ... + def delete_experiment(self, _id: str) -> bool: + ... @abstractmethod - def tags(self) -> TagCollection: ... + def tags(self) -> TagCollection: + ... @abstractmethod - def search_tags(self, term: str) -> TagCollection: ... + def search_tags(self, term: str) -> TagCollection: + ... @abstractmethod - def find_tag(self, _id: str) -> Tag: ... + def find_tag(self, _id: str) -> Tag: + ... @abstractmethod - def create_tag(self, name: str) -> Tag: ... + def create_tag(self, name: str) -> Tag: + ... @abstractmethod - def delete_tag(self, name: str) -> bool: ... + def delete_tag(self, name: str) -> bool: + ... diff --git a/aim/storage/treeview.py b/aim/storage/treeview.py index fc05a06f6..f80beff50 100644 --- a/aim/storage/treeview.py +++ b/aim/storage/treeview.py @@ -8,21 +8,26 @@ class TreeView: - def preload(self): ... + def preload(self): + ... - def finalize(self, index: 'TreeView'): ... + def finalize(self, index: 'TreeView'): + ... def subtree(self, path: Union[AimObjectKey, AimObjectPath]) -> 'TreeView': # Default to: return self.view(path, resolve=False) - def view(self, path: Union[AimObjectKey, AimObjectPath], resolve: bool = False): ... + def view(self, path: Union[AimObjectKey, AimObjectPath], resolve: bool = False): + ... - def make_array(self, path: Union[AimObjectKey, AimObjectPath] = ()): ... + def make_array(self, path: Union[AimObjectKey, AimObjectPath] = ()): + ... 
def collect( self, path: Union[AimObjectKey, AimObjectPath] = (), strict: bool = True, resolve_objects: bool = False - ) -> AimObject: ... + ) -> AimObject: + ... def __getitem__(self, path: Union[AimObjectKey, AimObjectPath]) -> AimObject: return self.collect(path) @@ -33,7 +38,8 @@ def get(self, path: Union[AimObjectKey, AimObjectPath] = (), default: Any = None except KeyError: return default - def __delitem__(self, path: Union[AimObjectKey, AimObjectPath]): ... + def __delitem__(self, path: Union[AimObjectKey, AimObjectPath]): + ... def set(self, path: Union[AimObjectKey, AimObjectPath], value: AimObject, strict: bool = True): self.__setitem__(path, value) @@ -45,18 +51,25 @@ def __setitem__(self, path: Union[AimObjectKey, AimObjectPath], value: AimObject def keys_eager( self, path: Union[AimObjectKey, AimObjectPath] = (), - ): ... + ): + ... def keys( self, path: Union[AimObjectKey, AimObjectPath] = (), level: int = None - ) -> Iterator[Union[AimObjectPath, AimObjectKey]]: ... + ) -> Iterator[Union[AimObjectPath, AimObjectKey]]: + ... - def items_eager(self, path: Union[AimObjectKey, AimObjectPath] = ()): ... + def items_eager(self, path: Union[AimObjectKey, AimObjectPath] = ()): + ... - def items(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> Iterator[Tuple[AimObjectKey, AimObject]]: ... + def items(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> Iterator[Tuple[AimObjectKey, AimObject]]: + ... - def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> 'ArrayView': ... + def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> 'ArrayView': + ... - def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... + def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: + ... - def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... + def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: + ... 
From 9ee40a256d40d9aa2361529444f525a9b6b33a8a Mon Sep 17 00:00:00 2001 From: Larissa Poghosyan <43134338+larissapoghosyan@users.noreply.github.com> Date: Tue, 8 Apr 2025 15:28:52 +0100 Subject: [PATCH 26/30] [fix] Aim web ui integration in jupyter/colab (#3319) * api endpoint /status is not implemented, but we can rely on status code for /projects * implement retrying with exponential backoff --- aim/cli/manager/manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aim/cli/manager/manager.py b/aim/cli/manager/manager.py index 99922e708..52385b9d4 100644 --- a/aim/cli/manager/manager.py +++ b/aim/cli/manager/manager.py @@ -33,18 +33,18 @@ def check_startup_success(): import requests server_path = 'http://{}:{}{}'.format(args['--host'], args['--port'], args['--base-path']) - status_api = f'{server_path}/api/projects/status' - retry_count = 5 - sleep_interval = 1 + status_api = f'{server_path}/api/projects/' + retry_count = 10 + sleep_interval = 0.1 for _ in range(retry_count): + time.sleep(sleep_interval) + sleep_interval *= 2 try: response = requests.get(status_api) if response.status_code == 200: return True except Exception: pass - sleep_interval += 1 - time.sleep(sleep_interval) return False From 6a559f382bbbb63da6ee69e210d514492fcae087 Mon Sep 17 00:00:00 2001 From: Albert Torosyan <32957250+alberttorosyan@users.noreply.github.com> Date: Wed, 30 Apr 2025 16:40:44 +0400 Subject: [PATCH 27/30] [fix] Fallback to union db if index is missing (#3317) Co-authored-by: mihran113 --- aim/cli/runs/commands.py | 4 +- aim/cli/storage/commands.py | 4 +- aim/cli/up/commands.py | 4 +- aim/ext/transport/handlers.py | 4 +- aim/sdk/base_run.py | 5 +- aim/sdk/index_manager.py | 140 +++++++++++++----- aim/sdk/repo.py | 146 +++++++----------- aim/sdk/reporter/file_manager.py | 6 +- aim/sdk/run.py | 14 +- aim/sdk/run_status_manager.py | 8 +- aim/sdk/run_status_watcher.py | 9 +- aim/sdk/sequence.py | 2 +- aim/sdk/uri_service.py | 10 +- aim/sdk/utils.py | 5 +- aim/storage/arrayview.py | 12 +- aim/storage/artifacts/artifact_storage.py | 9 +- aim/storage/inmemorytreeview.py | 6 +- aim/storage/query.py | 3 +- aim/storage/structured/entities.py | 171 ++++++++-------------- aim/storage/treeview.py | 39 ++--- aim/storage/treeviewproxy.py | 4 - aim/storage/union.pyx | 5 +- aim/web/api/projects/project.py | 1 - 23 files changed, 265 insertions(+), 346 deletions(-) diff --git a/aim/cli/runs/commands.py b/aim/cli/runs/commands.py index 26a4b58d6..1696b3209 100644 --- a/aim/cli/runs/commands.py +++ b/aim/cli/runs/commands.py @@ -192,10 +192,10 @@ def update_metrics(ctx, yes): if not confirmed: return - index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) + index_manager = RepoIndexManager.get_index_manager(repo) hashes = repo.list_all_runs() for run_hash in tqdm.tqdm(hashes, desc='Updating runs', total=len(hashes)): - meta_tree = repo.request_tree('meta', run_hash, read_only=False, from_union=False) + meta_tree = repo.request_tree('meta', run_hash, read_only=False) meta_run_tree = meta_tree.subtree(('meta', 'chunks', run_hash)) try: # check if the Run has already been updated. 
diff --git a/aim/cli/storage/commands.py b/aim/cli/storage/commands.py index 32bfe01d5..3210f7a69 100644 --- a/aim/cli/storage/commands.py +++ b/aim/cli/storage/commands.py @@ -51,7 +51,7 @@ def to_3_11(ctx, hashes, yes): if not confirmed: return - index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) + index_manager = RepoIndexManager.get_index_manager(repo) for run_hash in tqdm(matched_hashes): try: run = Run(run_hash, repo=repo) @@ -97,7 +97,7 @@ def restore_runs(ctx, hashes, yes): return remaining_runs = [] - index_manager = RepoIndexManager.get_index_manager(repo, disable_monitoring=True) + index_manager = RepoIndexManager.get_index_manager(repo) for run_hash in tqdm(matched_hashes): try: restore_run_backup(repo, run_hash) diff --git a/aim/cli/up/commands.py b/aim/cli/up/commands.py index 4775e0aa8..6ad5e6e73 100644 --- a/aim/cli/up/commands.py +++ b/aim/cli/up/commands.py @@ -124,7 +124,9 @@ def up( if profiler: os.environ[AIM_PROFILER_KEY] = '1' - RepoIndexManager.get_index_manager(repo_inst) + index_mng = RepoIndexManager.get_index_manager(repo_inst) + index_mng.start() + run_status_mng = RunStatusManager(repo_inst) run_status_mng.start() try: diff --git a/aim/ext/transport/handlers.py b/aim/ext/transport/handlers.py index 7915bc105..23c9985d2 100644 --- a/aim/ext/transport/handlers.py +++ b/aim/ext/transport/handlers.py @@ -51,14 +51,12 @@ def get_tree(**kwargs): name = kwargs['name'] sub = kwargs['sub'] read_only = kwargs['read_only'] - from_union = kwargs['from_union'] index = kwargs['index'] timeout = kwargs['timeout'] - no_cache = kwargs.get('no_cache', False) if index: return ResourceRef(repo._get_index_tree(name, timeout)) else: - return ResourceRef(repo.request_tree(name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache)) + return ResourceRef(repo.request_tree(name, sub, read_only=read_only)) def get_structured_run(hash_, read_only, created_at, **kwargs): diff --git a/aim/sdk/base_run.py b/aim/sdk/base_run.py index 89edf63b0..f77c435d8 100644 --- a/aim/sdk/base_run.py +++ b/aim/sdk/base_run.py @@ -39,6 +39,7 @@ def __init__( if self.read_only: assert run_hash is not None self.hash = run_hash + self.meta_tree: TreeView = self.repo.request_tree('meta', read_only=True).subtree('meta') else: if run_hash is None: self.hash = generate_run_hash() @@ -48,10 +49,8 @@ def __init__( raise MissingRunError(f'Cannot find Run {run_hash} in aim Repo {self.repo.path}.') self._lock = self.repo.request_run_lock(self.hash) self._lock.lock(force=force_resume) + self.meta_tree: TreeView = self.repo.request_tree('meta', self.hash, read_only=False).subtree('meta') - self.meta_tree: TreeView = self.repo.request_tree( - 'meta', self.hash, read_only=read_only, from_union=True - ).subtree('meta') self.meta_run_tree: TreeView = self.meta_tree.subtree('chunks').subtree(self.hash) self._series_run_trees: Dict[int, TreeView] = None diff --git a/aim/sdk/index_manager.py b/aim/sdk/index_manager.py index ff8eaa723..166e6ae0e 100644 --- a/aim/sdk/index_manager.py +++ b/aim/sdk/index_manager.py @@ -1,15 +1,19 @@ +import hashlib import logging import os import queue import threading +import time from pathlib import Path +from typing import Dict import aimrocks.errors from aim.sdk.repo import Repo from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer +from watchdog.observers.api import ObservedWatch from watchdog.observers.polling import PollingObserver @@ -19,12 +23,17 @@ class NewChunkCreatedHandler(FileSystemEventHandler): def 
__init__(self, manager): self.manager = manager + self.known_chunks = set(p.name for p in self.manager.chunks_dir.iterdir() if p.is_dir()) - def on_created(self, event): - if event.is_directory and Path(event.src_path).parent == self.manager.chunks_dir: - chunk_name = os.path.basename(event.src_path) - logger.debug(f'Detected new chunk directory: {chunk_name}') - self.manager.monitor_chunk_directory(event.src_path) + def on_modified(self, event): + if event.is_directory and Path(event.src_path) == self.manager.chunks_dir: + current_chunks = set(p.name for p in self.manager.chunks_dir.iterdir() if p.is_dir()) + new_chunks = current_chunks - self.known_chunks + for chunk_name in new_chunks: + chunk_path = self.manager.chunks_dir / chunk_name + logger.debug(f'Detected new chunk directory: {chunk_name}') + self.manager.monitor_chunk_directory(chunk_path) + self.known_chunks = current_chunks class ChunkChangedHandler(FileSystemEventHandler): @@ -71,14 +80,14 @@ class RepoIndexManager: index_manager_pool = {} @classmethod - def get_index_manager(cls, repo: Repo, disable_monitoring: bool = False): + def get_index_manager(cls, repo: Repo): mng = cls.index_manager_pool.get(repo.path, None) if mng is None: - mng = RepoIndexManager(repo, disable_monitoring) + mng = RepoIndexManager(repo) cls.index_manager_pool[repo.path] = mng return mng - def __init__(self, repo: Repo, disable_monitoring: bool): + def __init__(self, repo: Repo): self.repo_path = repo.path self.repo = repo self.chunks_dir = Path(self.repo_path) / 'meta' / 'chunks' @@ -86,35 +95,71 @@ def __init__(self, repo: Repo, disable_monitoring: bool): self._corrupted_runs = set() - if not disable_monitoring: - self.indexing_queue = queue.PriorityQueue() - self.lock = threading.Lock() + self.indexing_queue = queue.PriorityQueue() + self.lock = threading.Lock() + + self.new_chunk_observer = Observer() + self.chunk_change_observer = PollingObserver() + + self.new_chunk_handler = NewChunkCreatedHandler(self) + self.chunk_change_handler = ChunkChangedHandler(self) + self._watches: Dict[str, ObservedWatch] = dict() + self.new_chunk_observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=False) - self.new_chunk_observer = Observer() - self.chunk_change_observer = PollingObserver() + self._stop_event = threading.Event() + self._index_thread = None + self._monitor_thread = None - self.new_chunk_handler = NewChunkCreatedHandler(self) - self.chunk_change_handler = ChunkChangedHandler(self) + def start(self): + self._stop_event.clear() + self.new_chunk_observer.start() + self.chunk_change_observer.start() - self.new_chunk_observer.schedule(self.new_chunk_handler, self.chunks_dir, recursive=True) - self.new_chunk_observer.start() + if not self._index_thread or not self._index_thread.is_alive(): + self._index_thread = threading.Thread(target=self._process_indexing_queue, daemon=True) + self._index_thread.start() - self._monitor_existing_chunks() - self.chunk_change_observer.start() + if not self._monitor_thread or not self._monitor_thread.is_alive(): + self._monitor_thread = threading.Thread(target=self._monitor_existing_chunks, daemon=True) + self._monitor_thread.start() - self._reindex_thread = threading.Thread(target=self._process_queue, daemon=True) - self._reindex_thread.start() + def stop(self): + self._stop_event.set() + self.new_chunk_observer.stop() + self.chunk_change_observer.stop() + if self._monitor_thread: + self._monitor_thread.join() + if self._index_thread: + self._index_thread.join() def _monitor_existing_chunks(self): - for 
chunk_path in self.chunks_dir.iterdir(): - if chunk_path.is_dir(): - logger.debug(f'Monitoring existing chunk: {chunk_path}') - self.monitor_chunk_directory(chunk_path) + while not self._stop_event.is_set(): + index_db = self.repo.request_tree('meta', read_only=True) + monitored_chunks = set(self._watches.keys()) + for chunk_path in self.chunks_dir.iterdir(): + if ( + chunk_path.is_dir() + and chunk_path.name not in monitored_chunks + and self._is_run_index_outdated(chunk_path.name, index_db) + ): + logger.debug(f'Monitoring existing chunk: {chunk_path}') + self.monitor_chunk_directory(chunk_path) + logger.debug(f'Triggering indexing for run {chunk_path.name}') + self.add_run_to_queue(chunk_path.name) + self.repo.container_pool.clear() + time.sleep(5) + + def _stop_monitoring_chunk(self, run_hash): + watch = self._watches.pop(run_hash, None) + if watch: + self.chunk_change_observer.unschedule(watch) + logger.debug(f'Stopped monitoring chunk: {run_hash}') def monitor_chunk_directory(self, chunk_path): """Ensure chunk directory is monitored using a single handler.""" - if str(chunk_path) not in self.chunk_change_observer._watches: - self.chunk_change_observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) + if chunk_path.name not in self._watches: + watch = self.chunk_change_observer.schedule(self.chunk_change_handler, chunk_path, recursive=True) + self._watches[chunk_path.name] = watch logger.debug(f'Started monitoring chunk directory: {chunk_path}') else: logger.debug(f'Chunk directory already monitored: {chunk_path}') @@ -127,8 +172,8 @@ def add_run_to_queue(self, run_hash): self.indexing_queue.put((timestamp, run_hash)) logger.debug(f'Run {run_hash} added to indexing queue with timestamp {timestamp}') - def _process_queue(self): - while True: + def _process_indexing_queue(self): + while not self._stop_event.is_set(): _, run_hash = self.indexing_queue.get() logger.debug(f'Indexing run {run_hash}...') self.index(run_hash) @@ -137,12 +182,41 @@ def _process_queue(self): def index(self, run_hash): index = self.repo._get_index_tree('meta', 0).view(()) try: - meta_tree = self.repo.request_tree( - 'meta', run_hash, read_only=True, from_union=False, no_cache=True, skip_read_optimization=True - ).subtree('meta') + run_checksum = self._get_run_checksum(run_hash) + meta_tree = self.repo.request_tree('meta', run_hash, read_only=True, skip_read_optimization=True).subtree( + 'meta' + ) meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) meta_run_tree.finalize(index=index) + index['index_cache', run_hash] = run_checksum + + if meta_run_tree.get('end_time') is not None: + logger.debug(f'Indexing thread detected finished run: {run_hash}. Stopping monitoring...') + self._stop_monitoring_chunk(run_hash) + except (aimrocks.errors.RocksIOError, aimrocks.errors.Corruption): - logger.warning(f"Indexing thread detected corrupted run '{run_hash}'. Skipping.") + logger.warning(f'Indexing thread detected corrupted run: {run_hash}. 
Skipping.') self._corrupted_runs.add(run_hash) return True + + def _is_run_index_outdated(self, run_hash, index_db): + return self._get_run_checksum(run_hash) != index_db.get(('index_cache', run_hash)) + + def _get_run_checksum(self, run_hash): + hash_obj = hashlib.md5() + + for root, dirs, files in os.walk(os.path.join(self.chunks_dir, run_hash)): + for name in sorted(files): # sort to ensure consistent order + if name.startswith('LOG'): # skip access logs + continue + filepath = os.path.join(root, name) + try: + stat = os.stat(filepath) + hash_obj.update(filepath.encode('utf-8')) + hash_obj.update(str(stat.st_mtime).encode('utf-8')) + hash_obj.update(str(stat.st_size).encode('utf-8')) + except FileNotFoundError: + # File might have been deleted between os.walk and os.stat + continue + + return hash_obj.hexdigest() diff --git a/aim/sdk/repo.py b/aim/sdk/repo.py index 151a56f86..1ffef1c9b 100644 --- a/aim/sdk/repo.py +++ b/aim/sdk/repo.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple from weakref import WeakValueDictionary +import aimrocks.errors + from aim.ext.cleanup import AutoClean from aim.ext.sshfs.utils import mount_remote_repo, unmount_remote_repo from aim.ext.task_queue.queue import TaskQueue @@ -127,9 +129,14 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona self.root_path = path self.path = os.path.join(self.root_path, get_aim_repo_name()) - if init: + if init and not self.is_remote_repo: os.makedirs(self.path, exist_ok=True) os.makedirs(os.path.join(self.path, 'locks'), exist_ok=True) + + # Make sure meta index db is created + path = os.path.join(self.path, 'meta', 'index') + RocksContainer(path, read_only=False) + if not self.is_remote_repo and not os.path.exists(self.path): if self._mount_root: unmount_remote_repo(self.root_path, self._mount_root) @@ -137,7 +144,6 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona self.container_pool: Dict[ContainerConfig, Container] = WeakValueDictionary() self.persistent_pool: Dict[ContainerConfig, Container] = dict() - self.container_view_pool: Dict[ContainerConfig, Container] = WeakValueDictionary() self._run_props_cache_hint = None self._encryption_key = None @@ -160,7 +166,7 @@ def __init__(self, path: str, *, read_only: Optional[bool] = None, init: Optiona @property def meta_tree(self): - return self.request_tree('meta', read_only=True, from_union=True).subtree('meta') + return self.request_tree('meta', read_only=True).subtree('meta') def __repr__(self) -> str: return f'' @@ -269,28 +275,6 @@ def get_version(cls, path: str): def is_remote_path(cls, path: str): return path.startswith('aim://') - def _get_container( - self, name: str, read_only: bool, from_union: bool = False, skip_read_optimization: bool = False - ) -> Container: - # TODO [AT]: refactor get container/tree logic to make it more simple - if self.read_only and not read_only: - raise ValueError('Repo is read-only') - - container_config = ContainerConfig(name, None, read_only=read_only) - container = self.container_pool.get(container_config) - if container is None: - if from_union: - # Temporarily use index db when getting data from union. 
- path = os.path.join(self.path, name, 'index') - container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) - self.persistent_pool[container_config] = container - else: - path = os.path.join(self.path, name) - container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) - self.container_pool[container_config] = container - - return container - def _get_index_tree(self, name: str, timeout: int): if not self.is_remote_repo: return self._get_index_container(name, timeout).tree() @@ -311,60 +295,30 @@ def _get_index_container(self, name: str, timeout: int) -> Container: return container - def request_tree( - self, - name: str, - sub: str = None, - *, - read_only: bool, - from_union: bool = False, # TODO maybe = True by default - no_cache: bool = False, - skip_read_optimization: bool = False, - ): + def request_tree(self, name: str, sub: str = None, *, read_only: bool, skip_read_optimization: bool = False): if not self.is_remote_repo: - return self.request( - name, - sub, - read_only=read_only, - from_union=from_union, - no_cache=no_cache, - skip_read_optimization=skip_read_optimization, + return self.request_container( + name, sub, read_only=read_only, skip_read_optimization=skip_read_optimization ).tree() else: - return ProxyTree(self._client, name, sub, read_only=read_only, from_union=from_union, no_cache=no_cache) + return ProxyTree(self._client, name, sub, read_only=read_only) - def request( - self, - name: str, - sub: str = None, - *, - read_only: bool, - from_union: bool = False, # TODO maybe = True by default - no_cache: bool = False, - skip_read_optimization: bool = False, - ): + def request_container(self, name: str, sub: str = None, *, read_only: bool, skip_read_optimization: bool = False): container_config = ContainerConfig(name, sub, read_only) - container_view = self.container_view_pool.get(container_config) - if container_view is None or no_cache: - if read_only: - if from_union: - path = name - else: - assert sub is not None - path = os.path.join(name, 'chunks', sub) - container = self._get_container( - path, read_only=True, from_union=from_union, skip_read_optimization=skip_read_optimization - ) + container = self.container_pool.get(container_config) + if container is None: + if sub is None: + try: + path = os.path.join(self.path, name, 'index') + container = RocksContainer(path, read_only=True, skip_read_optimization=skip_read_optimization) + except aimrocks.errors.RocksIOError: + path = os.path.join(self.path, name) + container = RocksUnionContainer(path, read_only=True) else: - assert sub is not None - path = os.path.join(name, 'chunks', sub) - container = self._get_container(path, read_only=False, from_union=False) - - container_view = container - if not no_cache: - self.container_view_pool[container_config] = container_view - - return container_view + path = os.path.join(self.path, name, 'chunks', sub) + container = RocksContainer(path, read_only=read_only, skip_read_optimization=skip_read_optimization) + self.container_pool[container_config] = container + return container def request_props(self, hash_: str, read_only: bool, created_at: 'datetime' = None): if self.is_remote_repo: @@ -755,9 +709,6 @@ def encryption_key(self): return encryption_key - def _get_meta_tree(self): - return self.request_tree('meta', read_only=True, from_union=True).subtree('meta') - @staticmethod def available_sequence_types(): return Sequence.registry.keys() @@ -779,7 +730,6 @@ def collect_sequence_info(self, 
sequence_types: Tuple[str, ...]) -> Dict[str, Di Returns: :obj:`dict`: Tree of sequences and their contexts groupped by sequence type. """ - meta_tree = self._get_meta_tree() sequence_traces = {} if isinstance(sequence_types, str): sequence_types = (sequence_types,) @@ -792,7 +742,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di dtype_traces = set() for dtype in dtypes: try: - dtype_trace_tree = meta_tree.collect(('traces_types', dtype)) + dtype_trace_tree = self.meta_tree.collect(('traces_types', dtype)) for ctx_id, seqs in dtype_trace_tree.items(): for seq_name in seqs.keys(): dtype_traces.add((ctx_id, seq_name)) @@ -800,7 +750,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di pass if 'float' in dtypes: # old sequences without dtype set are considered float sequences try: - dtype_trace_tree = meta_tree.collect('traces') + dtype_trace_tree = self.meta_tree.collect('traces') for ctx_id, seqs in dtype_trace_tree.items(): for seq_name in seqs.keys(): dtype_traces.add((ctx_id, seq_name)) @@ -808,7 +758,7 @@ def collect_sequence_info(self, sequence_types: Tuple[str, ...]) -> Dict[str, Di pass traces_info = defaultdict(list) for ctx_id, seq_name in dtype_traces: - traces_info[seq_name].append(meta_tree['contexts', ctx_id]) + traces_info[seq_name].append(self.meta_tree['contexts', ctx_id]) sequence_traces[seq_type] = traces_info return sequence_traces @@ -818,9 +768,8 @@ def collect_params_info(self) -> dict: Returns: :obj:`dict`: All runs meta-parameters. """ - meta_tree = self._get_meta_tree() try: - return meta_tree.collect('attrs', strict=False) + return self.meta_tree.collect('attrs', strict=False) except KeyError: return {} @@ -891,22 +840,13 @@ def _delete_run(self, run_hash): def _copy_run(self, run_hash, dest_repo): def copy_trees(): # copy run meta tree - source_meta_tree = self.request_tree( - 'meta', run_hash, read_only=True, from_union=False, no_cache=True - ).subtree('meta') - dest_meta_tree = dest_repo.request_tree( - 'meta', run_hash, read_only=False, from_union=False, no_cache=True - ).subtree('meta') - dest_meta_run_tree = dest_meta_tree.subtree('chunks').subtree(run_hash) + source_meta_tree = self.request_tree('meta', run_hash, read_only=True).subtree('meta') + dest_meta_tree = dest_repo.request_tree('meta', run_hash, read_only=False).subtree('meta') dest_meta_tree[...] = source_meta_tree[...] 
- dest_index = dest_repo._get_index_tree('meta', timeout=10).view(()) - dest_meta_run_tree.finalize(index=dest_index) # copy run series tree - source_series_run_tree = self.request_tree('seqs', run_hash, read_only=True, no_cache=True).subtree('seqs') - dest_series_run_tree = dest_repo.request_tree('seqs', run_hash, read_only=False, no_cache=True).subtree( - 'seqs' - ) + source_series_run_tree = self.request_tree('seqs', run_hash, read_only=True).subtree('seqs') + dest_series_run_tree = dest_repo.request_tree('seqs', run_hash, read_only=False).subtree('seqs') # copy v2 sequences source_v2_tree = source_series_run_tree.subtree(('v2', 'chunks', run_hash)) @@ -1014,6 +954,10 @@ def _restore_run(self, run_hash): restore_run_backup(self, run_hash) def _close_run(self, run_hash): + import datetime + + import pytz + def optimize_container(path, extra_options): rc = RocksContainer(path, read_only=True, **extra_options) rc.optimize_for_read() @@ -1024,6 +968,16 @@ def optimize_container(path, extra_options): lock_manager = LockManager(self.path) if lock_manager.release_locks(run_hash, force=True): + # Set run end time if locks are removed + meta_tree = self.request_tree( + 'meta', + run_hash, + read_only=False, + ).subtree('meta') + meta_run_tree = meta_tree.subtree('chunks').subtree(run_hash) + if not meta_run_tree.get('end_time'): + meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() + # Run rocksdb optimizations if container locks are removed meta_db_path = os.path.join(self.path, 'meta', 'chunks', run_hash) seqs_db_path = os.path.join(self.path, 'seqs', 'chunks', run_hash) @@ -1039,7 +993,7 @@ def _recreate_index(self): from aim.sdk.index_manager import RepoIndexManager - index_manager = RepoIndexManager.get_index_manager(self, disable_monitoring=True) + index_manager = RepoIndexManager.get_index_manager(self) # force delete the index db and the locks diff --git a/aim/sdk/reporter/file_manager.py b/aim/sdk/reporter/file_manager.py index 72633f084..80c2d9a85 100644 --- a/aim/sdk/reporter/file_manager.py +++ b/aim/sdk/reporter/file_manager.py @@ -10,12 +10,10 @@ class FileManager(object): @abstractmethod - def poll(self, pattern: str) -> Optional[str]: - ... + def poll(self, pattern: str) -> Optional[str]: ... @abstractmethod - def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): - ... + def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ... class LocalFileManager(FileManager): diff --git a/aim/sdk/run.py b/aim/sdk/run.py index 775aed973..b53bdf72a 100644 --- a/aim/sdk/run.py +++ b/aim/sdk/run.py @@ -82,9 +82,9 @@ def __init__(self, instance: 'Run') -> None: def add_extra_resource(self, resource) -> None: self.extra_resources.append(resource) - def finalize_run(self): + def set_run_end_time(self): """ - Finalize the run by indexing all the data. + Set Run end_time to mark it as finished. """ self.meta_run_tree['end_time'] = datetime.datetime.now(pytz.utc).timestamp() @@ -94,7 +94,7 @@ def empty_rpc_queue(self): def _close(self) -> None: """ - Close the `Run` instance resources and trigger indexing. + Close the `Run` instance resources. 
""" if self.read_only: logger.debug(f'Run {self.hash} is read-only, skipping cleanup') @@ -104,7 +104,7 @@ def _close(self) -> None: res.close() self.empty_rpc_queue() - self.finalize_run() + self.set_run_end_time() if self._heartbeat is not None: self._heartbeat.stop() if self._checkins is not None: @@ -725,12 +725,6 @@ def close(self): self._props = None self._cleanup_trees() - def finalize(self): - if self._resources is None: - return - - self._resources.finalize_run() - def dataframe( self, include_props: bool = True, diff --git a/aim/sdk/run_status_manager.py b/aim/sdk/run_status_manager.py index 71dc42eeb..e1fa6f3fc 100644 --- a/aim/sdk/run_status_manager.py +++ b/aim/sdk/run_status_manager.py @@ -1,13 +1,13 @@ -import time -import os import datetime -import pytz +import os import threading -from pathlib import Path +import time +from pathlib import Path from typing import Iterable import aimrocks.errors +import pytz from aim import Repo from aim.sdk.run_status_watcher import Event diff --git a/aim/sdk/run_status_watcher.py b/aim/sdk/run_status_watcher.py index ccf203bd5..422cbff12 100644 --- a/aim/sdk/run_status_watcher.py +++ b/aim/sdk/run_status_watcher.py @@ -83,16 +83,13 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None, self.message = message @abstractmethod - def is_sent(self): - ... + def is_sent(self): ... @abstractmethod - def update_last_sent(self): - ... + def update_last_sent(self): ... @abstractmethod - def get_msg_details(self): - ... + def get_msg_details(self): ... class StatusNotification(Notification): diff --git a/aim/sdk/sequence.py b/aim/sdk/sequence.py index de8c78e1d..dde9e215f 100644 --- a/aim/sdk/sequence.py +++ b/aim/sdk/sequence.py @@ -201,7 +201,7 @@ def numpy(self) -> Tuple[np.ndarray, List[np.ndarray]]: sort_indices = steps.argsort() columns = [arr[sort_indices] for arr in columns] steps = steps[sort_indices] - if last_step is not None and last_step != steps[-1]: + if last_step is not None and last_step > steps[-1]: step_hash = self.step_hash(last_step) # The `last_step` is provided by the meta tree which may potentially # be out of sync with the series tree. 
diff --git a/aim/sdk/uri_service.py b/aim/sdk/uri_service.py index 10d588918..062c05ac6 100644 --- a/aim/sdk/uri_service.py +++ b/aim/sdk/uri_service.py @@ -55,7 +55,7 @@ def request_batch(self, uri_batch: List[str]) -> Iterator[Dict[str, bytes]]: for uri, sub_name, resource_path in self.runs_pool[run_name]: container = run_containers.get(sub_name) if not container: - container = self._get_container(run_name, sub_name) + container = self.repo.request_container(sub_name, run_name, read_only=True) run_containers[sub_name] = container resource_path = decode_path(bytes.fromhex(resource_path)) @@ -70,11 +70,3 @@ def request_batch(self, uri_batch: List[str]) -> Iterator[Dict[str, bytes]]: # clear runs pool self.runs_pool.clear() - - def _get_container(self, run_name: str, sub_name: str): - if sub_name == 'meta': - container = self.repo.request(sub_name, run_name, from_union=True, read_only=True) - else: - container = self.repo.request(sub_name, run_name, read_only=True) - - return container diff --git a/aim/sdk/utils.py b/aim/sdk/utils.py index 6863600f9..0e5ff84fa 100644 --- a/aim/sdk/utils.py +++ b/aim/sdk/utils.py @@ -165,13 +165,12 @@ def flatten(d, parent_path=None): return all_paths subtrees_to_lookup = ('attrs', 'traces_types', 'contexts', 'traces') - repo_meta_tree = repo._get_meta_tree() # set of all repo paths that can be left dangling after run deletion repo_paths = set() for key in subtrees_to_lookup: try: - repo_paths.update(flatten(repo_meta_tree.collect(key, strict=False), parent_path=(key,))) + repo_paths.update(flatten(repo.meta_tree.collect(key, strict=False), parent_path=(key,))) except KeyError: pass @@ -179,7 +178,7 @@ def flatten(d, parent_path=None): for run_hash in tqdm(run_hashes): # construct unique paths set for each run run_paths = set() - run_meta_tree = repo.request_tree('meta', run_hash, from_union=False, read_only=True).subtree('meta') + run_meta_tree = repo.request_tree('meta', run_hash, read_only=True).subtree('meta') for key in subtrees_to_lookup: try: run_paths.update(flatten(run_meta_tree.collect(key, strict=False), parent_path=(key,))) diff --git a/aim/storage/arrayview.py b/aim/storage/arrayview.py index 2b9fd8954..4694c1eab 100644 --- a/aim/storage/arrayview.py +++ b/aim/storage/arrayview.py @@ -9,8 +9,7 @@ class ArrayView: when index values are not important. """ - def __iter__(self) -> Iterator[Any]: - ... + def __iter__(self) -> Iterator[Any]: ... def keys(self) -> Iterator[int]: """Return sparse indices iterator. @@ -44,16 +43,13 @@ def items(self) -> Iterator[Tuple[int, Any]]: """ ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... - def __getitem__(self, idx: Union[int, slice]): - ... + def __getitem__(self, idx: Union[int, slice]): ... # TODO implement append - def __setitem__(self, idx: int, val: Any): - ... + def __setitem__(self, idx: int, val: Any): ... def sparse_list(self) -> Tuple[List[int], List[Any]]: """Get sparse indices and values as :obj:`list`s.""" diff --git a/aim/storage/artifacts/artifact_storage.py b/aim/storage/artifacts/artifact_storage.py index e0bab8934..efa73cbd1 100644 --- a/aim/storage/artifacts/artifact_storage.py +++ b/aim/storage/artifacts/artifact_storage.py @@ -7,13 +7,10 @@ def __init__(self, url: str): self.url = url @abstractmethod - def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): - ... + def upload_artifact(self, file_path: str, artifact_path: str, block: bool = False): ... 
@abstractmethod - def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) -> str: - ... + def download_artifact(self, artifact_path: str, dest_dir: Optional[str] = None) -> str: ... @abstractmethod - def delete_artifact(self, artifact_path: str): - ... + def delete_artifact(self, artifact_path: str): ... diff --git a/aim/storage/inmemorytreeview.py b/aim/storage/inmemorytreeview.py index 1ce208594..7d02c347d 100644 --- a/aim/storage/inmemorytreeview.py +++ b/aim/storage/inmemorytreeview.py @@ -117,8 +117,6 @@ def iterlevel( def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> TreeArrayView: return TreeArrayView(self.subtree(path), dtype=dtype) - def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: - ... + def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... - def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: - ... + def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... diff --git a/aim/storage/query.py b/aim/storage/query.py index f8fa81fbb..82de23657 100644 --- a/aim/storage/query.py +++ b/aim/storage/query.py @@ -80,8 +80,7 @@ def __init__(self, expr: str): self.expr = expr @abstractmethod - def check(self, **params) -> bool: - ... + def check(self, **params) -> bool: ... def __call__(self, **params): return self.check(**params) diff --git a/aim/storage/structured/entities.py b/aim/storage/structured/entities.py index a43471ea7..900c422ec 100644 --- a/aim/storage/structured/entities.py +++ b/aim/storage/structured/entities.py @@ -13,281 +13,224 @@ class StructuredObject(ABC): @classmethod @abstractmethod - def fields(cls): - ... + def fields(cls): ... class Searchable(ABC, Generic[T]): @classmethod @abstractmethod - def find(cls, _id: str, **kwargs) -> Optional[T]: - ... + def find(cls, _id: str, **kwargs) -> Optional[T]: ... @classmethod @abstractmethod - def all(cls, **kwargs) -> Collection[T]: - ... + def all(cls, **kwargs) -> Collection[T]: ... @classmethod @abstractmethod - def search(cls, term: str, **kwargs) -> Collection[T]: - ... + def search(cls, term: str, **kwargs) -> Collection[T]: ... class Run(StructuredObject, Searchable['Run']): @property @abstractmethod - def hash(self) -> str: - ... + def hash(self) -> str: ... @property @abstractmethod - def name(self) -> Optional[str]: - ... + def name(self) -> Optional[str]: ... @name.setter @abstractmethod - def name(self, value: str): - ... + def name(self, value: str): ... @property @abstractmethod - def description(self) -> Optional[str]: - ... + def description(self) -> Optional[str]: ... @description.setter @abstractmethod - def description(self, value: str): - ... + def description(self, value: str): ... @property @abstractmethod - def archived(self) -> bool: - ... + def archived(self) -> bool: ... @archived.setter @abstractmethod - def archived(self, value: bool): - ... + def archived(self, value: bool): ... @property @abstractmethod - def experiment(self) -> Optional['Experiment']: - ... + def experiment(self) -> Optional['Experiment']: ... @experiment.setter @abstractmethod - def experiment(self, value: str): - ... + def experiment(self, value: str): ... @property @abstractmethod - def tags(self) -> TagCollection: - ... + def tags(self) -> TagCollection: ... @abstractmethod - def add_tag(self, value: str) -> 'Tag': - ... + def add_tag(self, value: str) -> 'Tag': ... @abstractmethod - def remove_tag(self, tag_name: str) -> bool: - ... 
+ def remove_tag(self, tag_name: str) -> bool: ... @property @abstractmethod - def info(self) -> 'RunInfo': - ... + def info(self) -> 'RunInfo': ... class Experiment(StructuredObject, Searchable['Experiment']): @property @abstractmethod - def uuid(self) -> str: - ... + def uuid(self) -> str: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @name.setter @abstractmethod - def name(self, value: str): - ... + def name(self, value: str): ... @property @abstractmethod - def description(self) -> Optional[str]: - ... + def description(self) -> Optional[str]: ... @description.setter @abstractmethod - def description(self, value: str): - ... + def description(self, value: str): ... @property @abstractmethod - def archived(self) -> bool: - ... + def archived(self) -> bool: ... @archived.setter @abstractmethod - def archived(self, value: bool): - ... + def archived(self, value: bool): ... @property @abstractmethod - def runs(self) -> RunCollection: - ... + def runs(self) -> RunCollection: ... class Tag(StructuredObject, Searchable['Tag']): @property @abstractmethod - def uuid(self) -> str: - ... + def uuid(self) -> str: ... @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @name.setter @abstractmethod - def name(self, value: str): - ... + def name(self, value: str): ... @property @abstractmethod - def color(self) -> str: - ... + def color(self) -> str: ... @color.setter @abstractmethod - def color(self, value: str): - ... + def color(self, value: str): ... @property @abstractmethod - def description(self) -> str: - ... + def description(self) -> str: ... @description.setter @abstractmethod - def description(self, value: str): - ... + def description(self, value: str): ... @property @abstractmethod - def archived(self) -> bool: - ... + def archived(self) -> bool: ... @archived.setter @abstractmethod - def archived(self, value: bool): - ... + def archived(self, value: bool): ... @property @abstractmethod - def runs(self) -> RunCollection: - ... + def runs(self) -> RunCollection: ... class Note(StructuredObject, Searchable['Note']): @property @abstractmethod - def id(self) -> int: - ... + def id(self) -> int: ... @property @abstractmethod - def content(self) -> str: - ... + def content(self) -> str: ... @content.setter @abstractmethod - def content(self, value: str): - ... + def content(self, value: str): ... @property @abstractmethod - def run(self) -> int: - ... + def run(self) -> int: ... class RunInfo(StructuredObject, Generic[T]): @property @abstractmethod - def last_notification_index(self) -> int: - ... + def last_notification_index(self) -> int: ... @last_notification_index.setter @abstractmethod - def last_notification_index(self, value: int): - ... + def last_notification_index(self, value: int): ... class ObjectFactory: @abstractmethod - def runs(self) -> RunCollection: - ... + def runs(self) -> RunCollection: ... @abstractmethod - def search_runs(self, term: str) -> RunCollection: - ... + def search_runs(self, term: str) -> RunCollection: ... @abstractmethod - def find_run(self, _id: str) -> Run: - ... + def find_run(self, _id: str) -> Run: ... @abstractmethod - def find_runs(self, ids: List[str]) -> List[Run]: - ... + def find_runs(self, ids: List[str]) -> List[Run]: ... @abstractmethod - def create_run(self, runhash: str) -> Run: - ... + def create_run(self, runhash: str) -> Run: ... @abstractmethod - def delete_run(self, runhash: str) -> bool: - ... + def delete_run(self, runhash: str) -> bool: ... 
@abstractmethod - def experiments(self) -> ExperimentCollection: - ... + def experiments(self) -> ExperimentCollection: ... @abstractmethod - def search_experiments(self, term: str) -> ExperimentCollection: - ... + def search_experiments(self, term: str) -> ExperimentCollection: ... @abstractmethod - def find_experiment(self, _id: str) -> Experiment: - ... + def find_experiment(self, _id: str) -> Experiment: ... @abstractmethod - def create_experiment(self, name: str) -> Experiment: - ... + def create_experiment(self, name: str) -> Experiment: ... @abstractmethod - def delete_experiment(self, _id: str) -> bool: - ... + def delete_experiment(self, _id: str) -> bool: ... @abstractmethod - def tags(self) -> TagCollection: - ... + def tags(self) -> TagCollection: ... @abstractmethod - def search_tags(self, term: str) -> TagCollection: - ... + def search_tags(self, term: str) -> TagCollection: ... @abstractmethod - def find_tag(self, _id: str) -> Tag: - ... + def find_tag(self, _id: str) -> Tag: ... @abstractmethod - def create_tag(self, name: str) -> Tag: - ... + def create_tag(self, name: str) -> Tag: ... @abstractmethod - def delete_tag(self, name: str) -> bool: - ... + def delete_tag(self, name: str) -> bool: ... diff --git a/aim/storage/treeview.py b/aim/storage/treeview.py index f80beff50..fc05a06f6 100644 --- a/aim/storage/treeview.py +++ b/aim/storage/treeview.py @@ -8,26 +8,21 @@ class TreeView: - def preload(self): - ... + def preload(self): ... - def finalize(self, index: 'TreeView'): - ... + def finalize(self, index: 'TreeView'): ... def subtree(self, path: Union[AimObjectKey, AimObjectPath]) -> 'TreeView': # Default to: return self.view(path, resolve=False) - def view(self, path: Union[AimObjectKey, AimObjectPath], resolve: bool = False): - ... + def view(self, path: Union[AimObjectKey, AimObjectPath], resolve: bool = False): ... - def make_array(self, path: Union[AimObjectKey, AimObjectPath] = ()): - ... + def make_array(self, path: Union[AimObjectKey, AimObjectPath] = ()): ... def collect( self, path: Union[AimObjectKey, AimObjectPath] = (), strict: bool = True, resolve_objects: bool = False - ) -> AimObject: - ... + ) -> AimObject: ... def __getitem__(self, path: Union[AimObjectKey, AimObjectPath]) -> AimObject: return self.collect(path) @@ -38,8 +33,7 @@ def get(self, path: Union[AimObjectKey, AimObjectPath] = (), default: Any = None except KeyError: return default - def __delitem__(self, path: Union[AimObjectKey, AimObjectPath]): - ... + def __delitem__(self, path: Union[AimObjectKey, AimObjectPath]): ... def set(self, path: Union[AimObjectKey, AimObjectPath], value: AimObject, strict: bool = True): self.__setitem__(path, value) @@ -51,25 +45,18 @@ def __setitem__(self, path: Union[AimObjectKey, AimObjectPath], value: AimObject def keys_eager( self, path: Union[AimObjectKey, AimObjectPath] = (), - ): - ... + ): ... def keys( self, path: Union[AimObjectKey, AimObjectPath] = (), level: int = None - ) -> Iterator[Union[AimObjectPath, AimObjectKey]]: - ... + ) -> Iterator[Union[AimObjectPath, AimObjectKey]]: ... - def items_eager(self, path: Union[AimObjectKey, AimObjectPath] = ()): - ... + def items_eager(self, path: Union[AimObjectKey, AimObjectPath] = ()): ... - def items(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> Iterator[Tuple[AimObjectKey, AimObject]]: - ... + def items(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> Iterator[Tuple[AimObjectKey, AimObject]]: ... 
- def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> 'ArrayView': - ... + def array(self, path: Union[AimObjectKey, AimObjectPath] = (), dtype: Any = None) -> 'ArrayView': ... - def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: - ... + def first_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... - def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: - ... + def last_key(self, path: Union[AimObjectKey, AimObjectPath] = ()) -> AimObjectKey: ... diff --git a/aim/storage/treeviewproxy.py b/aim/storage/treeviewproxy.py index f459a096a..d2e188e84 100644 --- a/aim/storage/treeviewproxy.py +++ b/aim/storage/treeviewproxy.py @@ -24,8 +24,6 @@ def __init__( sub: str, *, read_only: bool, - from_union: bool = False, - no_cache: bool = False, index=False, timeout=None, ): @@ -38,10 +36,8 @@ def __init__( 'name': name, 'sub': sub, 'read_only': read_only, - 'from_union': from_union, 'index': index, 'timeout': timeout, - 'no_cache': no_cache, } self.init_args = pack_args(encode_tree(kwargs)) self.resource_type = 'TreeView' diff --git a/aim/storage/union.pyx b/aim/storage/union.pyx index 2d5729c75..e9bafc577 100644 --- a/aim/storage/union.pyx +++ b/aim/storage/union.pyx @@ -242,11 +242,8 @@ class DB(object): index_db = None logger.info('No index was detected') - # If index exists -- only load those in progress - selector = 'progress' if index_db is not None else 'chunks' - new_dbs: Dict[bytes, aimrocks.DB] = {} - db_dir = os.path.join(self.db_path, self.db_name, selector) + db_dir = os.path.join(self.db_path, self.db_name, 'chunks') for prefix in self._list_dir(db_dir): path = os.path.join(self.db_path, self.db_name, "chunks", prefix) prefix = encode_path((self.db_name, "chunks", prefix)) diff --git a/aim/web/api/projects/project.py b/aim/web/api/projects/project.py index 2fa29ee7a..b1ae57eba 100644 --- a/aim/web/api/projects/project.py +++ b/aim/web/api/projects/project.py @@ -18,7 +18,6 @@ def __init__(self): def cleanup_repo_pools(self): self.repo.container_pool.clear() - self.repo.container_view_pool.clear() self.repo.persistent_pool.clear() def cleanup_sql_caches(self): From a1a233b55aa6849079c3e34b356eeb59bfec72c6 Mon Sep 17 00:00:00 2001 From: Albert Torosyan Date: Thu, 8 May 2025 11:19:02 +0400 Subject: [PATCH 28/30] Bump up Aim to v3.29.0 --- CHANGELOG.md | 9 ++++++++- aim/VERSION | 2 +- aim/web/ui/package.json | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d55c7567e..6518d276c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,16 @@ # Changelog -## Unreleased: +## 3.29.0 May 8, 2025: + +### Enhancements: +- Constant indexing of in-progress runs (alberttorosyan) +- Fallback to union view if index db is missing (alberttorosyan, mihran113) + ### Fixes: - Fix min/max calculation for single point metrics (mihran113) +- Aim web ui integration in jupyter/colab (larissapoghosyan) + ## 3.28.0 Mar 21, 2025 diff --git a/aim/VERSION b/aim/VERSION index a72fd67b6..c7c977326 100644 --- a/aim/VERSION +++ b/aim/VERSION @@ -1 +1 @@ -3.28.0 +3.29.0 diff --git a/aim/web/ui/package.json b/aim/web/ui/package.json index c9a9976a0..d5519f277 100644 --- a/aim/web/ui/package.json +++ b/aim/web/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui_v2", - "version": "3.28.0", + "version": "3.29.0", "private": true, "dependencies": { "@aksel/structjs": "^1.0.0", From 753f4b18437b8288e1c6f7c894c14a33cba9e7d0 Mon Sep 17 00:00:00 2001 From: Albert 
Torosyan Date: Thu, 8 May 2025 13:42:47 +0400 Subject: [PATCH 29/30] Bump up Aim to v3.29.1 --- .github/workflows/python-package.yml | 1 + CHANGELOG.md | 4 +++- aim/VERSION | 2 +- aim/web/ui/package.json | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c414f5536..f8116d54a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -96,6 +96,7 @@ jobs: python -m pip install -r requirements.txt - name: Build bdist wheels for 'cp37-cp37m' + if: matrix.manylinux-version == 'manylinux_2_24_x86_64' uses: nick-fields/retry@v2 with: max_attempts: 3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 6518d276c..f3b114a0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## 3.29.0 May 8, 2025: +## 3.29.1 May 8, 2025: ### Enhancements: - Constant indexing of in-progress runs (alberttorosyan) @@ -10,7 +10,9 @@ ### Fixes: - Fix min/max calculation for single point metrics (mihran113) - Aim web ui integration in jupyter/colab (larissapoghosyan) +- Package publishing for Linux/Python 3.7 (alberttorosyan) +## 3.29.0 May 8, 2025 (Yanked) ## 3.28.0 Mar 21, 2025 diff --git a/aim/VERSION b/aim/VERSION index c7c977326..1002be7fb 100644 --- a/aim/VERSION +++ b/aim/VERSION @@ -1 +1 @@ -3.29.0 +3.29.1 diff --git a/aim/web/ui/package.json b/aim/web/ui/package.json index d5519f277..99ebb2bb8 100644 --- a/aim/web/ui/package.json +++ b/aim/web/ui/package.json @@ -1,6 +1,6 @@ { "name": "ui_v2", - "version": "3.29.0", + "version": "3.29.1", "private": true, "dependencies": { "@aksel/structjs": "^1.0.0", From d67e7663fad36ba57705723d46d16ef0c6240007 Mon Sep 17 00:00:00 2001 From: mihran113 Date: Thu, 26 Jun 2025 18:30:56 +0400 Subject: [PATCH 30/30] [fix] Resolve issues with false tag reassignment (#3344) --- CHANGELOG.md | 5 ++ aim/sdk/data_version.py | 2 +- .../migrations/versions/661514b12ee1_.py | 69 +++++++++++++++++++ aim/storage/structured/db.py | 3 +- aim/storage/structured/sql_engine/entities.py | 8 +-- aim/storage/structured/sql_engine/models.py | 8 ++- 6 files changed, 85 insertions(+), 10 deletions(-) create mode 100644 aim/storage/migrations/versions/661514b12ee1_.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f3b114a0a..bf3ba2d6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased: + +### Fixes: +- Fix issues with tag false reassignment (mihran113) + ## 3.29.1 May 8, 2025: ### Enhancements: diff --git a/aim/sdk/data_version.py b/aim/sdk/data_version.py index 55f4f52d6..4c496ddb5 100644 --- a/aim/sdk/data_version.py +++ b/aim/sdk/data_version.py @@ -1 +1 @@ -DATA_VERSION = (1, 3) +DATA_VERSION = (1, 4) diff --git a/aim/storage/migrations/versions/661514b12ee1_.py b/aim/storage/migrations/versions/661514b12ee1_.py new file mode 100644 index 000000000..eacfccfe8 --- /dev/null +++ b/aim/storage/migrations/versions/661514b12ee1_.py @@ -0,0 +1,69 @@ +"""empty message + +Revision ID: 661514b12ee1 +Revises: 46b89d830ad8 +Create Date: 2025-06-05 19:52:31.221392 + +""" +from alembic import op +import sqlalchemy as sa +from alembic.context import get_context + + +# revision identifiers, used by Alembic. 
+revision = '661514b12ee1' +down_revision = '46b89d830ad8' +branch_labels = None +depends_on = None + + + +def upgrade(): + # Get the SQLite connection context + context = get_context() + naming_convention = { + "fk": + "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + } + # Use batch operations for SQLite + with op.batch_alter_table('run_tag', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_run_tag_run_id_run', type_='foreignkey') + batch_op.drop_constraint('fk_run_tag_tag_id_tag', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_run_tag_run_id_run', 'run', ['run_id'], ['id'], ondelete='CASCADE') + batch_op.create_foreign_key('fk_run_tag_tag_id_tag', 'tag', ['tag_id'], ['id'], ondelete='CASCADE') + + + with op.batch_alter_table('note', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_note_run_id_run', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_note_run_id_run', 'run', ['run_id'], ['id'], ondelete='CASCADE') + + +def downgrade(): + # Use batch operations for SQLite + naming_convention = { + "fk": + "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + } + # Use batch operations for SQLite + with op.batch_alter_table('run_tag', naming_convention=naming_convention) as batch_op: + # Drop the CASCADE foreign key + batch_op.drop_constraint('fk_run_tag_run_id_run', type_='foreignkey') + batch_op.drop_constraint('fk_run_tag_tag_id_tag', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_run_tag_run_id_run', 'run', ['run_id'], ['id'],) + batch_op.create_foreign_key('fk_run_tag_tag_id_tag', 'tag', ['tag_id'], ['id'],) + + with op.batch_alter_table('note', naming_convention=naming_convention) as batch_op: + # First drop the existing foreign key + batch_op.drop_constraint('fk_note_run_id_run', type_='foreignkey') + + # Then create a new one with CASCADE + batch_op.create_foreign_key('fk_note_run_id_run', 'run', ['run_id'], ['id'],) + diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index 830c0bc41..cf0087a57 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -9,7 +9,7 @@ ) from aim.storage.types import SafeNone from aim.web.configs import AIM_LOG_LEVEL_KEY -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event from sqlalchemy.orm import scoped_session, sessionmaker @@ -66,6 +66,7 @@ def __init__(self, path: str, readonly: bool = False): pool_size=10, max_overflow=20, ) + event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on')) self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) self._upgraded = None diff --git a/aim/storage/structured/sql_engine/entities.py b/aim/storage/structured/sql_engine/entities.py index 554d4c70d..21d08626f 100644 --- a/aim/storage/structured/sql_engine/entities.py +++ b/aim/storage/structured/sql_engine/entities.py @@ -87,11 +87,9 @@ def from_hash(cls, runhash: str, created_at, session) -> 'ModelMappedRun': @classmethod def delete_run(cls, runhash: str, session) -> bool: - try: - rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete() - session_commit_or_flush(session) - except Exception: - return False + rows_affected = session.query(RunModel).filter(RunModel.hash == runhash).delete() + 
session_commit_or_flush(session) + return rows_affected > 0 @classmethod diff --git a/aim/storage/structured/sql_engine/models.py b/aim/storage/structured/sql_engine/models.py index 1c78c539e..9859a8d85 100644 --- a/aim/storage/structured/sql_engine/models.py +++ b/aim/storage/structured/sql_engine/models.py @@ -28,7 +28,7 @@ def default_to_run_hash(context): run_tags = Table( 'run_tag', Base.metadata, - Column('run_id', Integer, ForeignKey('run.id'), primary_key=True, nullable=False), + Column('run_id', Integer, ForeignKey('run.id', ondelete='CASCADE'), primary_key=True, nullable=False), Column('tag_id', Integer, ForeignKey('tag.id'), primary_key=True, nullable=False), ) @@ -51,7 +51,9 @@ class Run(Base): experiment_id = Column(ForeignKey('experiment.id'), nullable=True) experiment = relationship('Experiment', backref=backref('runs', uselist=True, order_by='Run.created_at.desc()')) - tags = relationship('Tag', secondary=run_tags, backref=backref('runs', uselist=True)) + tags = relationship( + 'Tag', secondary=run_tags, backref=backref('runs', uselist=True), cascade='all, delete', passive_deletes=True + ) notes = relationship('Note', back_populates='run') def __init__(self, run_hash, created_at=None): @@ -106,7 +108,7 @@ class Note(Base): id = Column(Integer, autoincrement=True, primary_key=True) content = Column(Text, nullable=False, default='') - run_id = Column(Integer, ForeignKey('run.id')) + run_id = Column(Integer, ForeignKey('run.id', ondelete='CASCADE'),) experiment_id = Column(Integer, ForeignKey('experiment.id')) created_at = Column(DateTime, default=datetime.datetime.utcnow)
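For context on the foreign-key changes in this last patch: SQLite does not enforce foreign keys (and therefore `ON DELETE CASCADE`) unless `PRAGMA foreign_keys` is switched on per connection, which is what the new `connect` event listener in `db.py` does. Below is a self-contained sketch of the same pattern with SQLAlchemy Core; the schema is a simplified, illustrative stand-in for the `run`/`tag`/`run_tag` association, not Aim's actual models:

```python
import sqlalchemy as sa
from sqlalchemy import event

engine = sa.create_engine('sqlite:///example.db')

# SQLite ships with foreign-key enforcement disabled; enable it on every new
# DBAPI connection so ON DELETE CASCADE actually fires.
event.listen(engine, 'connect', lambda dbapi_conn, _: dbapi_conn.execute('pragma foreign_keys=on'))

metadata = sa.MetaData()
run = sa.Table('run', metadata, sa.Column('id', sa.Integer, primary_key=True))
tag = sa.Table('tag', metadata, sa.Column('id', sa.Integer, primary_key=True))
run_tag = sa.Table(
    'run_tag', metadata,
    sa.Column('run_id', sa.Integer, sa.ForeignKey('run.id', ondelete='CASCADE'), primary_key=True),
    sa.Column('tag_id', sa.Integer, sa.ForeignKey('tag.id'), primary_key=True),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(run.insert().values(id=1))
    conn.execute(tag.insert().values(id=1))
    conn.execute(run_tag.insert().values(run_id=1, tag_id=1))
    # Deleting the run cascades to the association row, so the tag cannot be
    # falsely "reassigned" to a later run that reuses the same integer id.
    conn.execute(run.delete().where(run.c.id == 1))
    assert conn.execute(sa.select(sa.func.count()).select_from(run_tag)).scalar() == 0
```

Without the pragma listener the cascade clause is silently ignored on SQLite, which is why the Alembic migration alone is not sufficient and `db.py` is patched alongside it.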