diff --git a/aim/storage/migrations/env.py b/aim/storage/migrations/env.py index be4e13695..eba39c234 100644 --- a/aim/storage/migrations/env.py +++ b/aim/storage/migrations/env.py @@ -9,7 +9,8 @@ from alembic.config import Config from sqlalchemy import create_engine -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 # this is the Alembic Config object, which provides diff --git a/aim/storage/structured/db.py b/aim/storage/structured/db.py index a58b9a9ec..cb5503ffb 100644 --- a/aim/storage/structured/db.py +++ b/aim/storage/structured/db.py @@ -12,7 +12,8 @@ from sqlalchemy import create_engine, event from sqlalchemy.orm import scoped_session, sessionmaker -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 class ObjectCache: def __init__(self, data_fetch_func, key_func): @@ -47,8 +48,6 @@ def __getitem__(self, key): class DB(ObjectFactory): - _DB_NAME = 'app' - _DEFAULT_PORT = 5432 _pool = WeakValueDictionary() _caches = dict() @@ -57,17 +56,25 @@ class DB(ObjectFactory): def __init__(self, path: str, readonly: bool = False): import logging - super().__init__() - pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] - self.path = pg_dbname - self.db_url = self.get_db_url(self.path) + super().__init__() + if os.environ.get("AIM_USE_PG", False): + self.path = os.environ['AIM_PG_DBNAME_RUNS'] + engine_options = { + "pool_pre_ping": True, + } + else: + self.path = path + engine_options = { + "pool_size": 10, + "max_overflow": 20, + } + + self.db_url = self.get_db_url(self.path) self.readonly = readonly self.engine = create_engine( self.db_url, echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), - pool_pre_ping=True - # pool_size=10, - # max_overflow=20, + **engine_options, ) event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on')) self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine)) @@ -82,18 +89,26 @@ def from_path(cls, path: str, readonly: bool = False): return db @staticmethod - def get_default_url(): - pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] - return DB.get_db_url(pg_dbname) + def get_default_url(): + return DB.get_db_url(".aim") @staticmethod def get_db_url(path: str) -> str: - pg_user = os.environ['AIM_PG_USER'] - pg_password = os.environ['AIM_PG_PASSWORD'] - pg_host = os.environ['AIM_PG_HOST'] - pg_port = os.environ['AIM_PG_PORT'] + if os.environ.get("AIM_USE_PG", False): + pg_dbname = os.environ['AIM_PG_DBNAME_RUNS'] + pg_user = os.environ['AIM_PG_USER'] + pg_password = os.environ['AIM_PG_PASSWORD'] + pg_host = os.environ['AIM_PG_HOST'] + pg_port = os.environ['AIM_PG_PORT'] + db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" + else: + db_dialect = "sqlite" + db_name = "run_metadata.sqlite" + if os.path.exists(path): + db_url = f'{db_dialect}:///{path}/{db_name}' + else: + raise RuntimeError(f'Cannot find database {path}. Please init first.') - db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{path}" return db_url @property diff --git a/aim/web/api/db.py b/aim/web/api/db.py index 562c840ca..ec09305fe 100644 --- a/aim/web/api/db.py +++ b/aim/web/api/db.py @@ -9,14 +9,21 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -import aim.storage.drop_table_cascade # noqa: F401 +if os.environ.get("AIM_USE_PG", False): + import aim.storage.drop_table_cascade # noqa: F401 + engine_options = {} +else: + engine_options = { + "connect_args": {"check_same_thread": False}, + "pool_size": 10, + "max_overflow": 20, + } + engine = create_engine( get_db_url(), echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))), - # connect_args={'check_same_thread': False}, - # pool_size=10, - # max_overflow=20, + **engine_options, ) SessionLocal = sessionmaker(autoflush=False, bind=engine) diff --git a/aim/web/utils.py b/aim/web/utils.py index 140e79d78..0f38aa4ab 100644 --- a/aim/web/utils.py +++ b/aim/web/utils.py @@ -40,11 +40,14 @@ def get_root_path(): def get_db_url(): - pg_user = os.environ['AIM_PG_USER'] - pg_password = os.environ['AIM_PG_PASSWORD'] - pg_host = os.environ['AIM_PG_HOST'] - pg_port = os.environ['AIM_PG_PORT'] - pg_dbname = os.environ['AIM_PG_DBNAME_WEB'] - - db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" - return db_url + if os.environ.get("AIM_USE_PG", False): + pg_user = os.environ['AIM_PG_USER'] + pg_password = os.environ['AIM_PG_PASSWORD'] + pg_host = os.environ['AIM_PG_HOST'] + pg_port = os.environ['AIM_PG_PORT'] + pg_dbname = os.environ['AIM_PG_DBNAME_WEB'] + + return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}" + else: + return 'sqlite:///{}/{}/aim_db'.format(get_root_path(), get_aim_repo_name()) +