Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ class DatabaseConfig(BaseSettings):
instance_name: str = Field(
description="The name of the database instance", validation_alias="PGAPPNAME"
)
endpoint_name: str | None = Field(
description="The name of the Lakebase autoscaling endpoint (new postgres API)",
default=None,
validation_alias="PGENDPOINT",
)

@property
def use_autoscaling(self) -> bool:
"""Whether to use the new Lakebase autoscaling (postgres) API."""
return self.endpoint_name is not None


# --- Engine creation ---
Expand All @@ -60,15 +70,22 @@ def _build_engine_url(
return f"postgresql+psycopg://{username}:{password}@localhost:{dev_port}/postgres?sslmode=disable"

# Production mode: use Databricks Database
logger.info(f"Using Databricks database instance: {db_config.instance_name}")
instance = ws.database.get_database_instance(db_config.instance_name)
prefix = "postgresql+psycopg"
host = instance.read_write_dns
port = db_config.port
database = db_config.database_name
username = (
ws.config.client_id if ws.config.client_id else ws.current_user.me().user_name
)

if db_config.use_autoscaling:
logger.info(f"Using Lakebase autoscaling endpoint: {db_config.endpoint_name}")
endpoint = ws.postgres.get_endpoint(db_config.endpoint_name)
host = endpoint.status.hosts.host
else:
logger.info(f"Using Databricks database instance: {db_config.instance_name}")
instance = ws.database.get_database_instance(db_config.instance_name)
host = instance.read_write_dns

return f"{prefix}://{username}:@{host}:{port}/{database}"


Expand All @@ -90,9 +107,14 @@ def create_db_engine(db_config: DatabaseConfig, ws: WorkspaceClient) -> Engine:
engine = create_engine(engine_url, **engine_kwargs)

def before_connect(dialect, conn_rec, cargs, cparams):
cred = ws.database.generate_database_credential(
instance_names=[db_config.instance_name]
)
if db_config.use_autoscaling:
cred = ws.postgres.generate_database_credential(
endpoint=db_config.endpoint_name
)
else:
cred = ws.database.generate_database_credential(
instance_names=[db_config.instance_name]
)
cparams["password"] = cred.token

if not dev_port:
Expand All @@ -107,6 +129,17 @@ def validate_db(engine: Engine, db_config: DatabaseConfig) -> None:

if dev_port:
logger.info(f"Validating local dev database connection at localhost:{dev_port}")
elif db_config.use_autoscaling:
logger.info(
f"Validating Lakebase autoscaling endpoint {db_config.endpoint_name}"
)
try:
ws = WorkspaceClient()
ws.postgres.get_endpoint(db_config.endpoint_name)
except NotFound:
raise ValueError(
f"Lakebase endpoint {db_config.endpoint_name} does not exist"
)
else:
logger.info(
f"Validating database connection to instance {db_config.instance_name}"
Expand All @@ -128,6 +161,10 @@ def validate_db(engine: Engine, db_config: DatabaseConfig) -> None:

if dev_port:
logger.info("Local dev database connection validated successfully")
elif db_config.use_autoscaling:
logger.info(
f"Lakebase autoscaling endpoint {db_config.endpoint_name} validated successfully"
)
else:
logger.info(
f"Database connection to instance {db_config.instance_name} validated successfully"
Expand Down