Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aim/storage/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from alembic.config import Config
from sqlalchemy import create_engine

import aim.storage.drop_table_cascade # noqa: F401
if os.environ.get("AIM_USE_PG", False):
import aim.storage.drop_table_cascade # noqa: F401


# this is the Alembic Config object, which provides
Expand Down
51 changes: 33 additions & 18 deletions aim/storage/structured/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from sqlalchemy import create_engine, event
from sqlalchemy.orm import scoped_session, sessionmaker

import aim.storage.drop_table_cascade # noqa: F401
if os.environ.get("AIM_USE_PG", False):
import aim.storage.drop_table_cascade # noqa: F401

class ObjectCache:
def __init__(self, data_fetch_func, key_func):
Expand Down Expand Up @@ -47,8 +48,6 @@ def __getitem__(self, key):


class DB(ObjectFactory):
_DB_NAME = 'app'
_DEFAULT_PORT = 5432
_pool = WeakValueDictionary()

_caches = dict()
Expand All @@ -57,17 +56,25 @@ class DB(ObjectFactory):
def __init__(self, path: str, readonly: bool = False):
import logging

super().__init__()
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
self.path = pg_dbname
self.db_url = self.get_db_url(self.path)
super().__init__()
if os.environ.get("AIM_USE_PG", False):
self.path = os.environ['AIM_PG_DBNAME_RUNS']
engine_options = {
"pool_pre_ping": True,
}
else:
self.path = path
engine_options = {
"pool_size": 10,
"max_overflow": 20,
}

self.db_url = self.get_db_url(self.path)
self.readonly = readonly
self.engine = create_engine(
self.db_url,
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))),
pool_pre_ping=True
# pool_size=10,
# max_overflow=20,
**engine_options,
)
event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on'))
self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine))
Expand All @@ -82,18 +89,26 @@ def from_path(cls, path: str, readonly: bool = False):
return db

@staticmethod
def get_default_url():
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
return DB.get_db_url(pg_dbname)
def get_default_url():
return DB.get_db_url(".aim")

@staticmethod
def get_db_url(path: str) -> str:
pg_user = os.environ['AIM_PG_USER']
pg_password = os.environ['AIM_PG_PASSWORD']
pg_host = os.environ['AIM_PG_HOST']
pg_port = os.environ['AIM_PG_PORT']
if os.environ.get("AIM_USE_PG", False):
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
pg_user = os.environ['AIM_PG_USER']
pg_password = os.environ['AIM_PG_PASSWORD']
pg_host = os.environ['AIM_PG_HOST']
pg_port = os.environ['AIM_PG_PORT']
db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
else:
db_dialect = "sqlite"
db_name = "run_metadata.sqlite"
if os.path.exists(path):
db_url = f'{db_dialect}:///{path}/{db_name}'
else:
raise RuntimeError(f'Cannot find database {path}. Please init first.')

db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{path}"
return db_url

@property
Expand Down
15 changes: 11 additions & 4 deletions aim/web/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

import aim.storage.drop_table_cascade # noqa: F401
if os.environ.get("AIM_USE_PG", False):
import aim.storage.drop_table_cascade # noqa: F401
engine_options = {}
else:
engine_options = {
"connect_args": {"check_same_thread": False},
"pool_size": 10,
"max_overflow": 20,
}


engine = create_engine(
get_db_url(),
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))),
# connect_args={'check_same_thread': False},
# pool_size=10,
# max_overflow=20,
**engine_options,
)

SessionLocal = sessionmaker(autoflush=False, bind=engine)
Expand Down
19 changes: 11 additions & 8 deletions aim/web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def get_root_path():


def get_db_url():
pg_user = os.environ['AIM_PG_USER']
pg_password = os.environ['AIM_PG_PASSWORD']
pg_host = os.environ['AIM_PG_HOST']
pg_port = os.environ['AIM_PG_PORT']
pg_dbname = os.environ['AIM_PG_DBNAME_WEB']

db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
return db_url
if os.environ.get("AIM_USE_PG", False):
pg_user = os.environ['AIM_PG_USER']
pg_password = os.environ['AIM_PG_PASSWORD']
pg_host = os.environ['AIM_PG_HOST']
pg_port = os.environ['AIM_PG_PORT']
pg_dbname = os.environ['AIM_PG_DBNAME_WEB']

return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
else:
return 'sqlite:///{}/{}/aim_db'.format(get_root_path(), get_aim_repo_name())