|
1 | 1 | package com.example.spring.app.llm; |
2 | 2 |
|
| 3 | +import com.example.spring.app.llm.dto.ChatStreamResponseDTO; |
| 4 | +import com.example.spring.app.llm.dto.ConversationDTO; |
| 5 | +import com.example.spring.app.llm.dto.MessageDTO; |
| 6 | +import com.example.spring.app.llm.springAiChatMemory.SpringAiChatMemoryService; |
| 7 | +import com.example.spring.app.llm.userConversation.UserConversationModel; |
| 8 | +import com.example.spring.app.llm.userConversation.UserConversationService; |
3 | 9 | import org.springframework.ai.chat.client.ChatClient; |
4 | | -import org.springframework.ai.chat.model.ChatResponse; |
| 10 | +import org.springframework.ai.chat.memory.ChatMemory; |
5 | 11 | import org.springframework.http.MediaType; |
6 | 12 | import org.springframework.web.bind.annotation.*; |
7 | 13 | import reactor.core.publisher.Flux; |
8 | 14 |
|
| 15 | +import java.time.Instant; |
| 16 | +import java.util.List; |
| 17 | + |
| 18 | +import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader; |
| 19 | + |
9 | 20 | @CrossOrigin |
10 | 21 | @RestController |
11 | | -@RequestMapping("/v1") |
| 22 | +@RequestMapping("/v1/chat") |
12 | 23 | public class LLMController { |
13 | 24 |
|
14 | 25 | private final ChatClient chatClient; |
| 26 | + private final UserConversationService userConversationService; |
| 27 | + private final SpringAiChatMemoryService springAiChatMemoryService; |
15 | 28 |
|
16 | | - public LLMController(ChatClient.Builder chatClientBuilder) { |
17 | | - this.chatClient = chatClientBuilder.build(); |
| 29 | + public LLMController(ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) { |
| 30 | + this.chatClient = chatClient; |
| 31 | + this.userConversationService = userConversationService; |
| 32 | + this.springAiChatMemoryService = springAiChatMemoryService; |
18 | 33 | } |
19 | 34 |
|
20 | | - @GetMapping("/ask-ai") |
21 | | - LLMAnswerDTO generation(String userInput) { |
22 | | - LLMAnswerDTO llmAnswerDTO = new LLMAnswerDTO(); |
23 | | - String response = this.chatClient.prompt() |
24 | | - .user(userInput) |
25 | | - .call() |
26 | | - .content(); |
| 35 | + @PostMapping(value = "/conversation/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) |
| 36 | + public Flux<ChatStreamResponseDTO> streamGeneration(@PathVariable String conversationId, @RequestBody LLMRequest request) { |
| 37 | + String userId = extractUserIdFromHeader(); |
27 | 38 |
|
28 | | - llmAnswerDTO.setAnswer(response); |
29 | | - return llmAnswerDTO; |
30 | | - } |
| 39 | + UserConversationModel conversation = |
| 40 | + conversationId.equals("new") |
| 41 | + ? userConversationService.createNewConversationForUser(userId) |
| 42 | + : userConversationService.getUserConversation(conversationId, userId); |
| 43 | + |
| 44 | + if (conversation == null) { |
| 45 | + throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID."); |
| 46 | + } |
31 | 47 |
|
32 | | - @GetMapping(value = "/stream-ai", produces = MediaType.TEXT_EVENT_STREAM_VALUE) |
33 | | - public Flux<ChatResponse> streamGeneration(@RequestParam String userInput) { |
34 | 48 | return chatClient.prompt() |
35 | | - .user(userInput) |
| 49 | + .user(userSpec -> userSpec.text(request.userInput())) |
| 50 | + .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversation.getConversationId())) |
36 | 51 | .stream() |
37 | | - .chatResponse(); |
| 52 | + .chatResponse() |
| 53 | + .map(chatResponse -> new ChatStreamResponseDTO(conversation.getConversationId(), chatResponse, Instant.now().toEpochMilli())); |
| 54 | + } |
| 55 | + |
| 56 | + |
| 57 | + // TODO: Add pagination to this endpoint |
| 58 | + @GetMapping("/conversation/history/{conversationId}") |
| 59 | + public List<MessageDTO> getConversationHistory(@PathVariable String conversationId) { |
| 60 | + String userId = extractUserIdFromHeader(); |
| 61 | + UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId); |
| 62 | + |
| 63 | + if (conversation == null) { |
| 64 | + throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID."); |
| 65 | + } |
| 66 | + |
| 67 | + return springAiChatMemoryService.findAllByConversationId(conversationId).stream() |
| 68 | + .map(chatMemoryModel -> new MessageDTO(chatMemoryModel.getContent(), chatMemoryModel.getType(), chatMemoryModel.getTimestamp())) |
| 69 | + .toList(); |
| 70 | + } |
| 71 | + |
| 72 | + @GetMapping("/conversation/all") |
| 73 | + public List<ConversationDTO> getAllUserConversations() { |
| 74 | + String userId = extractUserIdFromHeader(); |
| 75 | + List<UserConversationModel> conversations = userConversationService.getAllConversationsForUser(userId); |
| 76 | + |
| 77 | + return conversations.stream() |
| 78 | + .map(conv -> new ConversationDTO(conv.getConversationId())) |
| 79 | + .toList(); |
38 | 80 | } |
39 | 81 | } |
0 commit comments