Skip to content

Commit 9be8bb9

Browse files
authored
Merge pull request #50 from pattern-tech/feat/wallet-accounting
Feat/wallet accounting
2 parents d4d9a86 + a018b29 commit 9be8bb9

File tree

2 files changed

+81
-5
lines changed

2 files changed

+81
-5
lines changed

api/src/conversation/routers/playground_conversation_router.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from uuid import UUID
22
from enum import Enum
3-
from typing import List, Dict
43
from pydantic import BaseModel
54
from sqlalchemy.orm import Session
5+
from typing import List, Optional, Dict
66
from fastapi.responses import StreamingResponse
77
from fastapi import APIRouter, Depends, HTTPException, status
88

@@ -47,6 +47,7 @@ class CreateConversationInput(BaseModel):
4747
"""
4848
name: str
4949
project_id: UUID
50+
conversation_id: Optional[UUID] = None
5051

5152
class Config:
5253
from_attributes = True
@@ -82,6 +83,13 @@ class MessageInput(BaseModel):
8283
stream: bool = True
8384

8485

86+
class FirstMessage(BaseModel):
87+
"""
88+
Schema for the first message.
89+
"""
90+
message: str
91+
92+
8593
@router.post(
8694
"",
8795
response_model=ConversationOutput,
@@ -93,7 +101,7 @@ def create_conversation(
93101
input: CreateConversationInput,
94102
db: Session = Depends(get_db),
95103
service: ConversationService = Depends(get_conversation_service),
96-
user_id: UUID = Depends(authenticate_user),
104+
user_id: UUID = Depends(authenticate_user)
97105
):
98106
"""
99107
Create a new conversation.
@@ -108,7 +116,7 @@ def create_conversation(
108116
"""
109117
try:
110118
conversation = service.create_conversation(
111-
db, input.name, input.project_id, user_id)
119+
db, input.name, input.project_id, user_id, input.conversation_id)
112120
return global_response(conversation)
113121
except Exception as e:
114122
raise HTTPException(
@@ -326,3 +334,28 @@ async def send_message(
326334
except Exception as e:
327335
raise HTTPException(
328336
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
337+
338+
339+
@router.post(
340+
"/{project_id}/{conversation_id}/title-generation",
341+
summary="Auto Title Generation",
342+
description="Using LLM to generate title for conversation",
343+
response_description="LLM title generated"
344+
)
345+
async def generate_title(
346+
input: FirstMessage,
347+
project_id: UUID,
348+
conversation_id: UUID,
349+
db: Session = Depends(get_db),
350+
conversation_service: ConversationService = Depends(
351+
get_conversation_service),
352+
user_id: UUID = Depends(authenticate_user),
353+
):
354+
try:
355+
title = conversation_service.rename_title(
356+
db, conversation_id, user_id, input.message)
357+
return title
358+
except Exception as e:
359+
raise HTTPException(
360+
status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
361+
)

api/src/conversation/services/conversation_service.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from uuid import UUID
22
from typing import List
3-
from datetime import timedelta
43
from sqlalchemy.orm import Session
54
from fastapi import HTTPException, status
65
from langchain_core.messages.human import HumanMessage
76

87
from src.util.configuration import Config
98
from src.agentflow.agents.hub import AgentHub
109
from src.db.models import Conversation, QueryUsage
10+
from src.agentflow.utils.shared_tools import init_llm
1111
from src.user.services.user_service import UserService
1212
from src.agent.services.memory_service import MemoryService
1313
from src.project.services.project_service import ProjectService
@@ -31,7 +31,7 @@ def __init__(self):
3131
self.user_service = UserService()
3232

3333
def create_conversation(
34-
self, db_session: Session, name: str, project_id: UUID, user_id: UUID
34+
self, db_session: Session, name: str, project_id: UUID, user_id: UUID, conversation_id: UUID = None
3535
) -> Conversation:
3636
"""
3737
Creates a new conversation.
@@ -41,6 +41,7 @@ def create_conversation(
4141
name (str): The name of the conversation.
4242
project_id (UUID): The ID of the project the conversation belongs to.
4343
user_id (UUID): The ID of the user creating the conversation.
44+
conversation_id (UUID): The ID of the conversation (optional).
4445
4546
Returns:
4647
Conversation: The created conversation instance.
@@ -53,6 +54,11 @@ def create_conversation(
5354

5455
conversation = Conversation(
5556
name=name, project_id=project_id, user_id=user_id)
57+
58+
# conversation id is generated in frontend
59+
if conversation_id:
60+
conversation.id = conversation_id
61+
5662
return self.repository.create(db_session, conversation)
5763

5864
def get_conversation(
@@ -258,9 +264,46 @@ def get_history(
258264
memory = self.memory_service.get_memory(conversation_id)
259265

260266
history = []
267+
generated_id = 0
261268
for message in memory.messages:
262269
history.append({
270+
"id": generated_id,
263271
"role": "human" if isinstance(message, HumanMessage) else "ai",
264272
"content": message.content,
265273
})
274+
generated_id += 1
266275
return history
276+
277+
def rename_title(self, db_session: Session, conversation_id: UUID, user_id: UUID, message: str) -> str:
278+
"""
279+
Renames a conversation title using the LLM to generate a title.
280+
281+
Args:
282+
db_session (Session): The database session.
283+
conversation_id (UUID): The ID of the conversation to rename.
284+
user_id (UUID): The ID of the user renaming the conversation.
285+
message (str): The message to generate a title for.
286+
287+
Returns:
288+
str: The new title of the conversation.
289+
"""
290+
config = Config.get_config()
291+
self.llm = init_llm(service=config["llm"]["provider"],
292+
model_name=config["llm"]["model"],
293+
api_key=config["llm"]["api_key"],
294+
stream=False,
295+
callbacks=None)
296+
messages = [
297+
(
298+
"system",
299+
"You should generate a title (maximum 5 words) for this message",
300+
),
301+
("human", f"{message}"),
302+
]
303+
304+
title = self.llm.invoke(messages)
305+
306+
self.repository.update(db_session, conversation_id, {
307+
"name": title.content}, user_id)
308+
309+
return title.content

0 commit comments

Comments
 (0)