Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/apps/datasource/api/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def inner():
return await asyncio.to_thread(inner)


# not used
@router.post("/fieldEnum/{id}")
async def field_enum(session: SessionDep, id: int):
def inner():
Expand Down
56 changes: 39 additions & 17 deletions backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,21 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str:
return db_url


def get_extra_config(conf: DatasourceConf):
config_dict = {}
if conf.extraJdbc:
config_arr = conf.extraJdbc.split("&")
for config in config_arr:
kv = config.split("=")
if len(kv) == 2 and kv[0] and kv[1]:
config_dict[kv[0]] = kv[1]
else:
raise Exception(f'param: {config} is error')
return config_dict


def get_origin_connect(type: str, conf: DatasourceConf):
extra_config_dict = get_extra_config(conf)
if type == "sqlServer":
return pymssql.connect(
server=conf.host,
Expand All @@ -81,10 +95,12 @@ def get_origin_connect(type: str, conf: DatasourceConf):
password=conf.password,
database=conf.database,
timeout=conf.timeout,
tds_version='7.0' # options: '4.2', '7.0', '8.0' ...
tds_version='7.0', # options: '4.2', '7.0', '8.0' ...,
**extra_config_dict
)


# use sqlalchemy
def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
if conf.timeout is None:
Expand Down Expand Up @@ -135,9 +151,10 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
return False
else:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1', timeout=10).fetchall()
SQLBotLogUtil.info("success")
Expand All @@ -150,7 +167,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=10,
read_timeout=10) as conn, conn.cursor() as cursor:
read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1')
SQLBotLogUtil.info("success")
Expand All @@ -164,7 +181,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
user=conf.username,
password=conf.password,
timeout=10) as conn, conn.cursor() as cursor:
timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1')
SQLBotLogUtil.info("success")
Expand Down Expand Up @@ -221,16 +238,17 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema):
res = result.fetchall()
version = res[0][0]
else:
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
cursor.execute(sql, timeout=10)
cursor.execute(sql, timeout=10, **extra_config_dict)
res = cursor.fetchall()
version = res[0][0]
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=10,
read_timeout=10) as conn, conn.cursor() as cursor:
read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql)
res = cursor.fetchall()
version = res[0][0]
Expand Down Expand Up @@ -260,17 +278,18 @@ def get_schema(ds: CoreDatasource):
res_list = [item[0] for item in res]
return res_list
else:
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout)
res = cursor.fetchall()
res_list = [item[0] for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
timeout=conf.timeout) as conn, conn.cursor() as cursor:
timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute("""SELECT nspname FROM pg_namespace""")
res = cursor.fetchall()
res_list = [item[0] for item in res]
Expand All @@ -288,25 +307,26 @@ def get_tables(ds: CoreDatasource):
res_list = [TableSchema(*item) for item in res]
return res_list
else:
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, {"param": sql_param}, timeout=conf.timeout)
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, (sql_param,))
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
timeout=conf.timeout) as conn, conn.cursor() as cursor:
timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, (sql_param,))
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
Expand All @@ -328,25 +348,26 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
res_list = [ColumnSchema(*item) for item in res]
return res_list
else:
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, {"param1": p1, "param2": p2}, timeout=conf.timeout)
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, (p1, p2))
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
timeout=conf.timeout) as conn, conn.cursor() as cursor:
timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql, (p1, p2))
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
Expand Down Expand Up @@ -379,9 +400,10 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
raise ParseSQLResultError(str(ex))
else:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
port=conf.port) as conn, conn.cursor() as cursor:
port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql, timeout=conf.timeout)
res = cursor.fetchall()
Expand All @@ -400,7 +422,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql)
res = cursor.fetchall()
Expand All @@ -419,7 +441,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
timeout=conf.timeout) as conn, conn.cursor() as cursor:
timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql)
res = cursor.fetchall()
Expand Down