-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
320 lines (267 loc) · 11 KB
/
main.py
File metadata and controls
320 lines (267 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import os
from typing import Annotated
from uuid import uuid4
import json
from bson.objectid import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import JSONResponse, StreamingResponse, RedirectResponse
import httpx
import uvicorn
import shutil
from config import *
from text_splitters import character_splitter, get_split_docs
from rag import main as rag_main
from backend.meetings import router as meeting_router
from backend.mongo_config import *
from stt_inference import transcribe_audio_files_in_directory_with_model
from audio_splitter import split_audio
from mongodb_manager import save_to_mongoDB, delete_mongoDB_data, init_mongoDB, show_mongoDB_data
from utils.seoul_time import get_current_time_str
from vectordb_manager import (
load_faiss_index,
save_faiss_index,
create_faiss_index_from_documents,
init_and_save_faiss_index,
show_faiss_index,
delete_faiss_index,
create_documents_from_texts,
put_metadata_to_documents,
add_documents_to_faiss_index
)
# Root directory for uploaded audio; each upload gets its own sub-folder beneath it.
# NOTE(review): PATH is not defined in this file — presumably it comes from one of
# the star imports above; confirm which module owns it.
audio_files_path = os.path.join(PATH, "audio_files")
async def upload_to_gridfs(file: UploadFile, bucket: AsyncIOMotorGridFSBucket) -> str:
    """Stream an uploaded file into GridFS and return the new file id as a string.

    The upload is copied chunk by chunk so that large audio files are never
    held in memory in one piece. If anything fails mid-copy, the partially
    written GridFS file is aborted before the error is re-raised.

    Args:
        file: The incoming upload; read from its current position to the end.
        bucket: The GridFS bucket to write into.

    Returns:
        The string form of the GridFS ``_id`` assigned to the stored file.
    """
    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory use bounded
    grid_in = bucket.open_upload_stream(file.filename)
    try:
        while chunk := await file.read(chunk_size):
            await grid_in.write(chunk)
    except Exception:
        # Drop whatever was already written so no truncated file is left behind.
        await grid_in.abort()
        raise
    await grid_in.close()
    return str(grid_in._id)
# Application instance; the handlers below register themselves on it via decorators.
app = FastAPI()
# Serve the meeting router's routes under the /meetings prefix.
app.include_router(meeting_router, prefix="/meetings", tags=["meetings"])
@app.on_event("startup")
async def startup_event():
    """Startup hook; at the moment it only logs that the server is starting."""
    # Placeholder: any startup-time configuration is planned to be added here.
    print("Server is starting up...")
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World"}
    return greeting
def init_local_data():
    """Reset the local audio directory to an empty folder.

    The whole ``audio_files_path`` tree is removed (a message is printed if
    it was not there) and the directory is then created again.
    """
    target_dir = audio_files_path
    try:
        shutil.rmtree(target_dir)
    except FileNotFoundError:
        print(f"The directory {target_dir} does not exist.")
    # Always leave an (empty) directory behind for subsequent uploads.
    os.makedirs(target_dir, exist_ok=True)
@app.delete("/initialization")
async def read_root():
    """Run every initialization helper and report completion.

    Calls, in order, the FAISS index initializer, the MongoDB initializer and
    the local audio directory reset.
    """
    # NOTE(review): this function reuses the name `read_root`, which the GET "/"
    # handler also uses; both routes are registered by their decorators, but the
    # module-level name ends up bound to this one. Consider renaming.
    init_and_save_faiss_index()
    await init_mongoDB()
    init_local_data()
    completion = {"Init": "Complete"}
    return completion
@app.get("/success")
async def success():
    """Return a generic success payload; other handlers redirect here with a 303."""
    payload = {"status": "success"}
    payload["detail"] = "방금 했던 요청 성공"
    return payload
@app.get("/showdb")
async def show_data():
    """Return what the FAISS and MongoDB "show" helpers report, under a ``data`` key."""
    # The FAISS helper is called first, then the MongoDB helper is awaited.
    faiss_view = show_faiss_index()
    mongo_view = await show_mongoDB_data()
    return {"data": {"FAISS": faiss_view, "MongoDB": mongo_view}}
@app.get("/answer")
async def get_anawer(query: str | None = None):
    """Answer a question by running the RAG entry point on ``query``.

    Args:
        query: The question, passed as a URL query parameter. It is declared
            optional so that a missing value reaches the explicit check below
            (with a required parameter that check could never run).

    Returns:
        ``{"result": ...}`` wrapping whatever ``rag_main`` returns.

    Raises:
        HTTPException: 400 when ``query`` is not supplied.
    """
    if query is None:
        raise HTTPException(status_code=400, detail="Query parameter not found")
    # The second argument is k, the number of documents to retrieve.
    result = rag_main(query, 10)
    return {"result": result}
@app.put("/documents")  # init or merge FAISS index
async def put_meeting_data(data: Annotated[str | None, Header()] = None):
    """Index a meeting transcript into FAISS, merging with the stored index.

    The ``data`` header carries a JSON object with ``transcript``, ``time``
    and ``meeting_id``. The transcript is split, tagged with the time and
    meeting id as metadata, turned into a new FAISS index, merged with the
    existing index when that one is non-empty, and saved.

    Args:
        data: JSON-encoded request payload, sent as the ``data`` header.

    Returns:
        A 303 redirect to ``/success``.

    Raises:
        HTTPException: 400 when the header is missing, is not valid JSON, or
            has no string ``transcript``; 500 when indexing or saving fails.
    """
    if data is None:
        raise HTTPException(status_code=400, detail="Data header not found")

    # Input problems are the client's fault: report them as 400, not 500.
    try:
        payload = json.loads(data)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Data header is not valid JSON: {e}")
    if not isinstance(payload, dict) or not isinstance(payload.get("transcript"), str):
        raise HTTPException(
            status_code=400,
            detail="Data header must be a JSON object with a string 'transcript'",
        )

    try:
        transcript = payload["transcript"]
        meeting_time = payload.get("time")  # arrives as a string
        meeting_id = payload.get("meeting_id")

        faiss = load_faiss_index()
        transcripts = character_splitter.split_text(transcript)
        documents = create_documents_from_texts(transcripts)
        metadata = {"time": meeting_time, "meeting_id": meeting_id}
        documents = put_metadata_to_documents(documents, metadata)
        new_faiss = create_faiss_index_from_documents(documents)
        if faiss.index.ntotal > 0:
            new_faiss.merge_from(faiss)
        # A failed save must not be reported as success, so it is no longer
        # swallowed here; the handler below turns it into a 500.
        save_faiss_index(new_faiss)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return RedirectResponse(url="/success", status_code=303)
@app.put("/document")  # add
async def add_meeting_data(data: Annotated[str | None, Header()] = None):
    """Store one transcript in MongoDB and add it to the FAISS index.

    The ``data`` header carries a JSON object with ``uuid``, ``title``,
    ``created_date`` and ``txt_path``. The text file at ``txt_path`` is read
    (an empty string is used if it does not exist), saved to MongoDB, and the
    split documents for that path are added to the FAISS index.

    Args:
        data: JSON-encoded request payload, sent as the ``data`` header.

    Returns:
        ``{"file_id": <uuid>, "detail": "Upload Success"}``.

    Raises:
        HTTPException: 400 when the header is missing or is not a JSON
            object; 500 when storing or indexing fails.
    """
    if data is None:
        raise HTTPException(status_code=400, detail="Data header not found")
    try:
        payload = json.loads(data)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Data header is not valid JSON: {e}")
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Data header must be a JSON object")

    try:
        uuid = payload.get("uuid")
        title = payload.get("title")
        created_date = payload.get("created_date")
        # NOTE(review): txt_path is taken from the request as-is, so a caller
        # chooses which server-side file is read — confirm this endpoint is
        # only reachable by trusted callers.
        txt_path = payload.get("txt_path")

        page_content = ''
        try:
            with open(txt_path, 'r', encoding='utf-8') as txt_file:
                page_content = txt_file.read()
        except FileNotFoundError:
            print(f"The file {txt_path} does not exist.")

        await save_to_mongoDB(uuid, page_content, title, created_date)
        add_documents_to_faiss_index(get_split_docs(txt_path, uuid))
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail=str(e))
    return {"file_id": uuid, "detail": "Upload Success"}
def delete_local_data(doc_id, audio_files_path):
    """Delete the local folder that belongs to ``doc_id``.

    ``doc_id`` comes from a request, so the resolved target must be a strict
    descendant of ``audio_files_path``. Values such as ``""``, ``"."``,
    ``".."`` or an absolute path would otherwise make ``rmtree`` remove the
    whole audio root or a directory outside it.

    Args:
        doc_id: Name of the sub-folder to remove.
        audio_files_path: Root directory that holds the per-document folders.

    Raises:
        ValueError: If ``doc_id`` does not resolve to a path strictly inside
            ``audio_files_path``.
    """
    base = os.path.realpath(audio_files_path)
    target = os.path.realpath(os.path.join(base, doc_id))
    # commonpath compares whole path components, so "/x/audio2" is not
    # mistaken for a child of "/x/audio".
    if target == base or os.path.commonpath([base, target]) != base:
        raise ValueError(f"Invalid doc_id: {doc_id!r}")
    try:
        shutil.rmtree(target)
    except FileNotFoundError:
        print(f"The directory {doc_id} does not exist.")
@app.delete("/documents")  # delete
async def delete_document(doc_id: Annotated[str | None, Header(convert_underscores=False)] = None):
    """Remove one document from FAISS, MongoDB and the local audio folder.

    Args:
        doc_id: Identifier of the document, read from the ``doc_id`` header
            (underscore conversion is disabled for this header).

    Returns:
        A 303 redirect to ``/success``.

    Raises:
        HTTPException: 400 when the header is absent; 500 when any of the
            three delete steps raises.
    """
    if doc_id is None:
        raise HTTPException(status_code=400, detail="doc_id header not found")

    try:
        delete_faiss_index(doc_id)
        await delete_mongoDB_data(doc_id)
        delete_local_data(doc_id, audio_files_path)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    else:
        return RedirectResponse(url="/success", status_code=303)
def save_audio_to_local(file: UploadFile, save_path):
    """Save an uploaded file into a fresh, uuid-named folder under the audio root.

    The client-supplied filename is reduced to its final path component
    before use, so a name such as ``../../x`` cannot place the file outside
    the new folder.

    Args:
        file: The uploaded file to persist.
        save_path: Currently unused; the file is always written under the
            module-level ``audio_files_path``. Kept for caller compatibility.

    Returns:
        A ``(uuid, file_path)`` tuple: the generated hex id and the path of
        the saved file.

    Raises:
        HTTPException: 400 when the upload has no usable filename; 500 when
            creating the folder or writing the file fails.
    """
    # Normalise Windows separators first so basename also strips "..\\" prefixes.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file has no usable filename")
    try:
        uuid = uuid4().hex
        uuid_path = os.path.join(audio_files_path, uuid)
        os.makedirs(uuid_path, exist_ok=True)
        file_path = os.path.join(uuid_path, safe_name)
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        return uuid, file_path
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def segment_and_STT(file_path) -> tuple[str, str]:
    """Split an audio file, transcribe the pieces and save the joined transcript.

    The segments are written to an ``outputs`` folder next to ``file_path``.
    The transcriptions are joined with blank lines and written to a ``.txt``
    file that shares the audio file's base name.

    Args:
        file_path: Path of the audio file to process.

    Returns:
        A ``(transcript, txt_file_path)`` tuple.

    Raises:
        HTTPException: 500 when splitting, transcription or writing fails.
    """
    try:
        # Segment: source audio location -> folder that receives the split files.
        output_path = os.path.join(os.path.dirname(file_path), "outputs")
        split_audio(file_path, output_dir=output_path)

        # Speech-to-text over every segment in the output folder.
        transcriptions = transcribe_audio_files_in_directory_with_model(
            output_path,
            model=STT_MODEL,
            processor=PROCESSOR,
            device=DEVICE,
        )
        transcriptions = "\n\n".join(transcriptions)

        # Swap the extension for .txt and store the transcript beside the audio.
        base, _ = os.path.splitext(file_path)
        txt_file_path = base + '.txt'
        with open(txt_file_path, 'w', encoding='utf-8') as txt_file:
            txt_file.write(transcriptions)
        return transcriptions, txt_file_path
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/process")
async def process_all(file: UploadFile = File(...)):
    """Run the whole pipeline for one uploaded audio file.

    Steps:
        1. Save the uploaded mp3 locally.
        2. Split the saved file and run STT to build the full transcript.
        3. Derive the title (file name) and the creation time.
        4. PUT the metadata to ``/document``, which stores the transcript in
           MongoDB and adds it to the FAISS index.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"response": <JSON body returned by /document>}``.

    Raises:
        HTTPException: Whatever a helper raised, unchanged; 500 for any other
            failure; the downstream status code when ``/document`` answers
            with an error.
    """
    try:
        audio_files_directory = 'audio_files'  # could be promoted to a module-level setting
        # Save locally.
        uuid, file_path = save_audio_to_local(file, audio_files_directory)
        # Split + STT; the transcript is also written to txt_path.
        _, txt_path = segment_and_STT(file_path)
        title = os.path.basename(file_path)  # the audio file name is the title
        created_date = get_current_time_str()
        # Only metadata goes into the header: /document reads the transcript
        # from txt_path, and a full transcript would overflow header size limits.
        data = {
            "uuid": uuid,
            "title": title,
            "created_date": created_date,
            "txt_path": txt_path,
        }
        headers = {"Content-Type": "application/json", "data": json.dumps(data)}
        # NOTE(review): the target URL is hard-coded to this server's default
        # host/port; it breaks if the app is served elsewhere.
        async with httpx.AsyncClient() as client:
            response = await client.put("http://127.0.0.1:8000/document", headers=headers)
    except HTTPException:
        # Helpers already raise well-formed HTTP errors; do not re-wrap them.
        raise
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail=str(e))

    # Do not report success when the downstream call failed.
    if response.status_code >= 400:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    return {"response": response.json()}
@app.post("/files")
async def upload_file(file: UploadFile = File(...)):
    """Store an upload in GridFS and keep a temporary local copy.

    Args:
        file: The uploaded file.

    Returns:
        A 200 JSON response with ``file_id`` (GridFS id) and ``uuid`` (name of
        the local folder that holds the copy).

    Raises:
        HTTPException: 400 when the upload has no usable filename.
    """
    # Only the final path component of the client-supplied name is used for
    # the local copy, so "../" sequences cannot escape the uuid folder.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file has no usable filename")

    # Upload to GridFS; the client is closed even when the upload raises.
    client = AsyncIOMotorClient(MONGO_URI)
    try:
        db = client[DATABASE_NAME]
        bucket = AsyncIOMotorGridFSBucket(db)
        file_id = await upload_to_gridfs(file, bucket)
    finally:
        client.close()

    # Rewind so the same upload can be copied again for the local file.
    await file.seek(0)

    # Temporary local copy inside a fresh uuid-named folder.
    uuid = uuid4().hex
    uuid_path = os.path.join(audio_files_path, uuid)
    os.makedirs(uuid_path, exist_ok=True)
    with open(os.path.join(uuid_path, safe_name), "wb") as f:
        shutil.copyfileobj(file.file, f)

    return JSONResponse(status_code=200, content={"file_id": str(file_id), "uuid": uuid})
@app.post("/segment/{uuid}")
async def segment_audio(uuid: str):
    """Split the audio file stored under ``uuid`` into segments.

    Args:
        uuid: Name of the upload folder below ``audio_files_path``.

    Returns:
        A 200 JSON response with ``num_files``, the value returned by
        ``split_audio``.

    Raises:
        HTTPException: 404 when the folder does not exist or holds no audio
            file; 500 when splitting fails.
    """
    uuid_path = os.path.join(audio_files_path, uuid)
    if not os.path.isdir(uuid_path):
        raise HTTPException(status_code=404, detail="UUID not found")

    # Pick the source deterministically. A bare listdir()[0] could return the
    # "outputs" folder from an earlier run or a saved .txt transcript.
    candidates = sorted(
        name
        for name in os.listdir(uuid_path)
        if os.path.isfile(os.path.join(uuid_path, name))
        and not name.lower().endswith(".txt")
    )
    if not candidates:
        raise HTTPException(status_code=404, detail="No audio file found for UUID")

    file_path = os.path.join(uuid_path, candidates[0])
    output_dir = os.path.join(uuid_path, "outputs")
    try:
        num_files = split_audio(file_path, output_dir=output_dir)
        return JSONResponse(status_code=200, content={"num_files": num_files})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/stt/{uuid}")
async def stt(uuid: str):
    """Transcribe the segments stored for ``uuid`` and return the joined text.

    Args:
        uuid: Name of the upload folder below ``audio_files_path``.

    Returns:
        A 200 JSON response whose ``transcript`` is the per-file
        transcriptions joined with blank lines.

    Raises:
        HTTPException: 404 when the folder does not exist; 500 when
            transcription fails.
    """
    session_dir = os.path.join(audio_files_path, uuid)
    if not os.path.exists(session_dir):
        raise HTTPException(status_code=404, detail="UUID not found")

    # The segments are read from the "outputs" sub-folder of the upload.
    segments_dir = os.path.join(session_dir, "outputs")
    try:
        pieces = transcribe_audio_files_in_directory_with_model(
            segments_dir,
            model=STT_MODEL,
            processor=PROCESSOR,
            device=DEVICE,
        )
        transcript = "\n\n".join(pieces)
        return JSONResponse(status_code=200, content={"transcript": transcript})
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.get("/files/{file_id}")
async def download_file(file_id: str):
    """Stream a file out of GridFS as an attachment.

    Args:
        file_id: String form of the GridFS ObjectId.

    Returns:
        A streaming ``application/octet-stream`` response whose
        ``Content-Disposition`` carries the stored filename.

    Raises:
        HTTPException: 400 when ``file_id`` is not a valid ObjectId; 404 when
            the download stream cannot be opened.
    """
    # Validate the id before opening a connection, so the 400 path leaks nothing.
    try:
        object_id = ObjectId(file_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    client = AsyncIOMotorClient(MONGO_URI)
    db = client[DATABASE_NAME]
    bucket = AsyncIOMotorGridFSBucket(db)
    try:
        grid_out = await bucket.open_download_stream(object_id)
    except Exception as e:
        client.close()
        raise HTTPException(status_code=404, detail=str(e))

    async def file_iterator():
        # The chunks are read while the response body is being sent, so the
        # client has to stay open until the stream is finished or abandoned.
        try:
            while True:
                chunk = await grid_out.readchunk()
                if not chunk:
                    break
                yield chunk
        finally:
            client.close()

    headers = {
        'Content-Disposition': f'attachment; filename="{grid_out.filename}"'
    }
    return StreamingResponse(file_iterator(), media_type='application/octet-stream', headers=headers)
if __name__ == "__main__":
    # Started as a script: serve the app with uvicorn on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")