Skip to content

Commit 0b588c4

Browse files
committed
feat: support Elasticsearch datasource #108
1 parent 10129b7 commit 0b588c4

File tree

12 files changed

+163
-24
lines changed

12 files changed

+163
-24
lines changed

backend/apps/chat/task/llm.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def generate_assistant_dynamic_sql(self, sql, tables: List):
625625
result_dict = {}
626626
for table in ds.tables:
627627
if table.name in tables and table.sql:
628-
#sub_query.append({"table": table.name, "query": table.sql})
628+
# sub_query.append({"table": table.name, "query": table.sql})
629629
result_dict[table.name] = table.sql
630630
sub_query.append({"table": table.name, "query": f'{dynamic_subsql_prefix}{table.name}'})
631631
if not sub_query:
@@ -881,7 +881,7 @@ def save_sql_data(self, data_obj: Dict[str, Any]):
881881
def finish(self):
882882
return finish_record(session=self.session, record_id=self.record.id)
883883

884-
def execute_sql(self, sql: str):
884+
def execute_sql(self, sql: str, tables):
885885
"""Execute SQL query
886886
887887
Args:
@@ -893,7 +893,7 @@ def execute_sql(self, sql: str):
893893
"""
894894
SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}")
895895
try:
896-
return exec_sql(self.ds, sql)
896+
return exec_sql(ds=self.ds, sql=sql, origin_column=False, table_name=tables)
897897
except Exception as e:
898898
if isinstance(e, ParseSQLResultError):
899899
raise e
@@ -1000,14 +1000,16 @@ def run_task(self, in_chat: bool = True):
10001000
sqlbot_temp_sql_text = None
10011001
assistant_dynamic_sql = None
10021002
# todo row permission
1003-
if ((not self.current_assistant or is_page_embedded) and is_normal_user(self.current_user)) or use_dynamic_ds:
1003+
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
1004+
self.current_user)) or use_dynamic_ds:
10041005
sql, tables = self.check_sql(res=full_sql_text)
10051006
sql_result = None
1006-
1007+
10071008
if use_dynamic_ds:
10081009
dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables)
1009-
sqlbot_temp_sql_text = dynamic_sql_result.get('sqlbot_temp_sql_text') if dynamic_sql_result else None
1010-
#sql_result = self.generate_assistant_filter(sql, tables)
1010+
sqlbot_temp_sql_text = dynamic_sql_result.get(
1011+
'sqlbot_temp_sql_text') if dynamic_sql_result else None
1012+
# sql_result = self.generate_assistant_filter(sql, tables)
10111013
else:
10121014
sql_result = self.generate_filter(sql, tables) # maybe no sql and tables
10131015

@@ -1020,6 +1022,7 @@ def run_task(self, in_chat: bool = True):
10201022
sql = self.check_save_sql(res=full_sql_text)
10211023
else:
10221024
sql = self.check_save_sql(res=full_sql_text)
1025+
tables = []
10231026

10241027
SQLBotLogUtil.info(sql)
10251028
format_sql = sqlparse.format(sql, reindent=True)
@@ -1033,10 +1036,11 @@ def run_task(self, in_chat: bool = True):
10331036
if sqlbot_temp_sql_text and assistant_dynamic_sql:
10341037
dynamic_sql_result.pop('sqlbot_temp_sql_text')
10351038
for origin_table, subsql in dynamic_sql_result.items():
1036-
assistant_dynamic_sql = assistant_dynamic_sql.replace(f'{dynamic_subsql_prefix}{origin_table}', subsql)
1039+
assistant_dynamic_sql = assistant_dynamic_sql.replace(f'{dynamic_subsql_prefix}{origin_table}',
1040+
subsql)
10371041
real_execute_sql = assistant_dynamic_sql
1038-
1039-
result = self.execute_sql(sql=real_execute_sql)
1042+
1043+
result = self.execute_sql(sql=real_execute_sql, tables=tables)
10401044
self.save_sql_data(data_obj=result)
10411045
if in_chat:
10421046
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'

backend/apps/db/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class DB(Enum):
2222
dm = ('dm', '"', '"', ConnectType.py_driver)
2323
doris = ('doris', '`', '`', ConnectType.py_driver)
2424
redshift = ('redshift', '"', '"', ConnectType.py_driver)
25+
es = ('es', '"', '"', ConnectType.py_driver)
2526

2627
def __init__(self, type, prefix, suffix, connect_type: ConnectType):
2728
self.type = type

backend/apps/db/db.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from common.core.deps import Trans
2525
from common.utils.utils import SQLBotLogUtil
2626
from fastapi import HTTPException
27+
from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data
2728

2829

2930
def get_uri(ds: CoreDatasource) -> str:
@@ -144,7 +145,8 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
144145
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
145146
return False
146147
elif ds.type == 'redshift':
147-
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
148+
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
149+
user=conf.username,
148150
password=conf.password,
149151
timeout=10) as conn, conn.cursor() as cursor:
150152
try:
@@ -156,6 +158,14 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
156158
if is_raise:
157159
raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
158160
return False
161+
elif ds.type == 'es':
162+
es_conn = get_es_connect(conf)
163+
if es_conn.ping():
164+
SQLBotLogUtil.info("success")
165+
return True
166+
else:
167+
SQLBotLogUtil.info("failed")
168+
return False
159169
else:
160170
conn = get_ds_engine(ds)
161171
try:
@@ -208,7 +218,7 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema):
208218
cursor.execute(sql)
209219
res = cursor.fetchall()
210220
version = res[0][0]
211-
elif ds.type == 'redshift':
221+
elif ds.type == 'redshift' or ds.type == 'es':
212222
version = ''
213223
except Exception as e:
214224
print(e)
@@ -285,6 +295,10 @@ def get_tables(ds: CoreDatasource):
285295
res = cursor.fetchall()
286296
res_list = [TableSchema(*item) for item in res]
287297
return res_list
298+
elif ds.type == 'es':
299+
res = get_es_index(conf)
300+
res_list = [TableSchema(*item) for item in res]
301+
return res_list
288302

289303

290304
def get_fields(ds: CoreDatasource, table_name: str = None):
@@ -321,9 +335,13 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
321335
res = cursor.fetchall()
322336
res_list = [ColumnSchema(*item) for item in res]
323337
return res_list
338+
elif ds.type == 'es':
339+
res = get_es_fields(conf, table_name)
340+
res_list = [ColumnSchema(*item) for item in res]
341+
return res_list
324342

325343

326-
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False):
344+
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False, table_name=None):
327345
while sql.endswith(';'):
328346
sql = sql[:-1]
329347

@@ -401,3 +419,16 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
401419
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
402420
except Exception as ex:
403421
raise ParseSQLResultError(str(ex))
422+
elif ds.type == 'es':
423+
if table_name and table_name[0]:
424+
res, columns = get_es_data(conf, sql, table_name[0])
425+
columns = [field[0] for field in columns] if origin_column else [field[0].lower() for
426+
field in
427+
columns]
428+
result_list = [
429+
{str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
430+
enumerate(tuple_item)}
431+
for tuple_item in res
432+
]
433+
return {"fields": columns, "data": result_list,
434+
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}

backend/apps/db/es_engine.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Author: Junjun
2+
# Date: 2025/9/9
3+
4+
import json
5+
6+
import requests
7+
from elasticsearch import Elasticsearch
8+
9+
from apps.datasource.models.datasource import DatasourceConf
10+
11+
12+
def get_es_connect(conf: DatasourceConf):
    """Create an Elasticsearch client for a datasource configuration.

    Args:
        conf: Datasource settings. ``host``, ``username`` and ``password``
            are read from it.

    Returns:
        A newly constructed ``Elasticsearch`` client.
    """
    hosts = [conf.host]  # ES address
    credentials = (conf.username, conf.password)
    # NOTE(review): TLS certificate verification is disabled for every
    # connection created here -- confirm this is intended for production.
    return Elasticsearch(
        hosts,
        basic_auth=credentials,
        verify_certs=False,
        compatibility_mode=True,
    )
20+
21+
22+
# get tables
23+
def get_es_index(conf: DatasourceConf):
    """List the indices of an Elasticsearch datasource with their descriptions.

    Every index reported by ``cat.indices`` is returned together with the
    ``description`` stored in its mapping-level ``_meta`` object.

    Args:
        conf: Datasource settings used to open the connection.

    Returns:
        A list of ``(index_name, description)`` tuples. ``description`` is an
        empty string when the mapping carries no description.
    """
    es_client = get_es_connect(conf)
    indices = es_client.cat.indices(format="json")
    res = []
    for idx in indices:
        index_name = idx.get('index')
        # get mapping
        mapping = es_client.indices.get_mapping(index=index_name)
        # Tolerate a missing entry / missing "mappings" instead of failing
        # the whole listing with an AttributeError on None.
        mappings = (mapping.get(index_name) or {}).get("mappings") or {}
        meta = mappings.get('_meta') or {}
        # "_meta" may exist without a "description" key; keep the result a
        # string in that case rather than leaking None to the caller.
        desc = meta.get('description') or ''
        res.append((index_name, desc))
    return res
37+
38+
39+
# get fields
40+
def get_es_fields(conf: DatasourceConf, table_name: str):
    """Describe the fields of one Elasticsearch index.

    Args:
        conf: Datasource settings used to open the connection.
        table_name: Name of the index whose mapping is read.

    Returns:
        A list of ``(field_name, field_type, description)`` tuples in mapping
        order. For fields without an explicit ``type`` (object, nested, ...)
        ``field_type`` is the comma-joined list of the keys of the field's
        mapping entry. The list is empty when the index has no mapped
        properties.
    """
    es_client = get_es_connect(conf)
    mapping = es_client.indices.get_mapping(index=table_name)
    index_mapping = mapping.get(table_name) or {}
    # An index without any mapped field has no "properties" key; treat that
    # as "no fields" instead of crashing on None.items().
    properties = (index_mapping.get("mappings") or {}).get("properties") or {}

    res = []
    for field, config in properties.items():
        # Elasticsearch names the per-field metadata parameter "meta"; the
        # "_meta" key checked originally is still honoured first so existing
        # behaviour is unchanged.
        meta = config.get("_meta") or config.get("meta") or {}
        desc = meta.get('description') or ''

        field_type = config.get("type")
        if not field_type:
            # object, nested, ...
            field_type = ','.join(config.keys())
        res.append((field, field_type, desc))
    return res
58+
59+
60+
def get_es_data(conf: DatasourceConf, sql: str, table_name: str):
    """Run a SQL statement against an Elasticsearch index.

    The SQL is first converted to a search request through the
    ``_sql/translate`` endpoint, then executed with the search API on
    ``table_name``.

    Args:
        conf: Datasource settings (``host``, ``username``, ``password``).
        sql: SQL statement to translate and execute.
        table_name: Index the translated search is executed on.

    Returns:
        A ``(rows, fields)`` tuple. ``fields`` is the result of
        ``get_es_fields`` for the index; ``rows`` is a list of tuples holding,
        for every hit that carries a ``fields`` object, the first value of
        each index field (``None`` when the hit has no value for it).

    Raises:
        requests.HTTPError: If the translate endpoint answers with an error
            status.
    """
    # Avoid "host//_sql/translate" when the configured address ends in "/".
    base_url = conf.host.rstrip('/')
    # Send the same credentials the Elasticsearch client uses; without them
    # the translate call is rejected by clusters with security enabled.
    auth = (conf.username, conf.password) if conf.username else None
    r = requests.post(
        f"{base_url}/_sql/translate",
        json={"query": sql},
        auth=auth,
        verify=False,  # mirrors verify_certs=False used by get_es_connect
        timeout=30,  # never block the caller forever on an unresponsive node
    )
    if not r.ok:
        # Surface the server's explanation instead of forwarding the error
        # payload to the search API as if it were a query body.
        raise requests.HTTPError(
            f"Elasticsearch SQL translate failed ({r.status_code}): {r.text}",
            response=r,
        )

    es_client = get_es_connect(conf)
    response = es_client.search(index=table_name, body=r.json())

    fields = get_es_fields(conf, table_name)
    hits = (response.get('hits') or {}).get('hits') or []
    # NOTE(review): only hits[].fields is read. A translated query whose
    # result is not delivered as hits (presumably aggregations / GROUP BY)
    # would produce no rows here -- TODO confirm and handle.
    res = []
    for hit in hits:
        result = hit.get('fields')  # {'title': ['Python'], 'age': [30]}
        if result is None:
            continue
        item = []
        for field in fields:
            v = result.get(field[0])
            item.append(v[0] if v else None)
        res.append(tuple(item))
    return res, fields

backend/apps/db/type.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ def db_type_relation() -> Dict:
1313
"ck": "ClickHouse",
1414
"dm": "达梦",
1515
"doris": "Apache Doris",
16-
"redshift": "AWS Redshift"
16+
"redshift": "AWS Redshift",
17+
"es": "Elasticsearch"
1718
}

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ dependencies = [
5050
"dicttoxml>=1.7.16",
5151
"dmpython>=2.5.22; platform_system != 'Darwin'",
5252
"redshift-connector>=2.1.8",
53+
"elasticsearch[requests] (>=7.10,<8.0)",
5354
]
5455

5556
[project.optional-dependencies]

backend/template.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ template:
7070
生成SQL时,必须避免与数据库关键字冲突
7171
</rule>
7272
<rule>
73-
如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift,则在schema、表名、字段名、别名外层加双引号;
73+
如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号;
7474
如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号;
7575
如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。
7676
<example>
@@ -448,7 +448,7 @@ template:
448448
- 如果存在冗余的过滤条件则进行去重后再生成新SQL。
449449
- 给过滤条件中的字段前加上表别名(如果没有表别名则加表名),如:table.field。
450450
- 生成SQL时,必须避免关键字冲突:
451-
- 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift,则在schema、表名、字段名、别名外层加双引号;
451+
- 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号;
452452
- 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号;
453453
- 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。
454454
- 生成的SQL使用JSON格式返回:
1.01 KB
Loading

frontend/src/i18n/en.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@
256256
"success": "Connect success",
257257
"failed": "Connect failed"
258258
},
259-
"timeout": "Timeout(second)"
259+
"timeout": "Timeout(second)",
260+
"address": "Address"
260261
}
261262
},
262263
"datasource": {

frontend/src/i18n/zh-CN.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@
256256
"success": "连接成功",
257257
"failed": "连接失败"
258258
},
259-
"timeout": "查询超时(秒)"
259+
"timeout": "查询超时(秒)",
260+
"address": "地址"
260261
}
261262
},
262263
"datasource": {

0 commit comments

Comments
 (0)