From 25342a8e763f4b492a446a8c9105ca3be17d7aa3 Mon Sep 17 00:00:00 2001 From: James Collins Date: Fri, 27 Jun 2025 16:00:07 -0700 Subject: [PATCH 1/2] Add branch for sqllite and postgres --- aim/storage/migrations/env.py | 3 ++- aim/storage/structured/db.py | 51 ++++++++++++++++++++++------------- aim/web/api/db.py | 15 ++++++++--- 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/aim/storage/migrations/env.py b/aim/storage/migrations/env.py index be4e13695..eba39c234 100644 --- a/aim/storage/migrations/env.py +++ b/aim/storage/migrations/env.py @@ -9,7 +9,8 @@ from alembic.config import Config from sqlalchemy import create_engine -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 # this is the Alembic Config object, which provides diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index a58b9a9ec..cb5503ffb 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -12,7 +12,8 @@ from sqlalchemy import create_engine, event from sqlalchemy.orm import scoped_session, sessionmaker -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 class ObjectCache: def __init__(self, data_fetch_func, key_func): @@ -47,8 +48,6 @@ def __getitem__(self, key): class DB(ObjectFactory): - _DB_NAME = 'app' - _DEFAULT_PORT = 5432 _pool = WeakValueDictionary() _caches = dict() @@ -57,17 +56,25 @@ class DB(ObjectFactory): def __init__(self, path: str, readonly: bool = False): import logging - super().__init__() - pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] - self.path = pg_dbname - self.db_url = self.get_db_url(self.path) + super().__init__() + if os.environ.get("AIM_USE_PG", False): + self.path = os.environ['AIM_PG_DBNAME_RUNS'] + engine_options = { + "pool_pre_ping": True, + } + else: + self.path = path + engine_options = { + "pool_size": 10, + "max_overflow": 20, + } + + self.db_url = self.get_db_url(self.path) self.readonly = readonly self.engine = create_engine( self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), - pool_pre_ping=True - # pool_size=10, - # max_overflow=20, + **engine_options, ) event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on')) self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) @@ -82,18 +89,26 @@ def from_path(cls, path: str, readonly: bool = False): return db @staticmethod - def get_default_url(): - pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] - return DB.get_db_url(pg_dbname) + def get_default_url(): + return DB.get_db_url(".aim") @staticmethod def get_db_url(path: str) -> str: - pg_user = os.environ['AIM_PG_USER'] - pg_password = os.environ['AIM_PG_PASSWORD'] - pg_host = os.environ['AIM_PG_HOST'] - pg_port = os.environ['AIM_PG_PORT'] + if os.environ.get("AIM_USE_PG", False): + pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] + pg_user = os.environ['AIM_PG_USER'] + pg_password = os.environ['AIM_PG_PASSWORD'] + pg_host = os.environ['AIM_PG_HOST'] + pg_port = os.environ['AIM_PG_PORT'] + db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" + else: + db_dialect = "sqlite" + db_name = "run_metadata.sqlite" + if os.path.exists(path): + db_url = f'{db_dialect}:///{path}/{db_name}' + else: + raise RuntimeError(f'Cannot find database {path}. Please init first.') - db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{path}" return db_url @property diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 562c840ca..ec09305fe 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -9,14 +9,21 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 + engine_options = {} +else: + engine_options = { + "connect_args": {"check_same_thread": False}, + "pool_size": 10, + "max_overflow": 20, + } + engine = create_engine( get_db_url(), echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), - # connect_args={'check_same_thread': False}, - # pool_size=10, - # max_overflow=20, + **engine_options, ) SessionLocal = sessionmaker(autoflush=False, bind=engine) From 9cb33c5ec511f9e926fa30f4fff24a5d6641d8d8 Mon Sep 17 00:00:00 2001 From: James Collins Date: Mon, 30 Jun 2025 08:14:14 -0700 Subject: [PATCH 2/2] Fix missing branch in web utils --- aim/web/utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aim/web/utils.py b/aim/web/utils.py index 140e79d78..0f38aa4ab 100644 --- a/aim/web/utils.py +++ b/aim/web/utils.py @@ -40,11 +40,14 @@ def get_root_path(): def get_db_url(): - pg_user = os.environ['AIM_PG_USER'] - pg_password = os.environ['AIM_PG_PASSWORD'] - pg_host = os.environ['AIM_PG_HOST'] - pg_port = os.environ['AIM_PG_PORT'] - pg_dbname = os.environ['AIM_PG_DBNAME_WEB'] - - db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" - return db_url + if os.environ.get("AIM_USE_PG", False): + pg_user = os.environ['AIM_PG_USER'] + pg_password = os.environ['AIM_PG_PASSWORD'] + pg_host = os.environ['AIM_PG_HOST'] + pg_port = os.environ['AIM_PG_PORT'] + pg_dbname = os.environ['AIM_PG_DBNAME_WEB'] + + return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" + else: + return 'sqlite:///{}/{}/aim_db'.format(get_root_path(), get_aim_repo_name()) +