Skip to content

Commit 22d3cd4

Browse files
committed
feat: ai memory
1 parent 9597223 commit 22d3cd4

15 files changed

Lines changed: 250 additions & 31 deletions

build.gradle

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ dependencies {
1717
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
1818
implementation 'org.springframework.boot:spring-boot-starter-oauth2-resource-server'
1919
implementation 'org.springframework.boot:spring-boot-starter-web'
20-
implementation platform("org.springframework.ai:spring-ai-bom:1.0.0")
21-
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
2220
implementation 'org.springframework.boot:spring-boot-starter-security'
2321
implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.6.0'
2422
implementation 'org.springframework.boot:spring-boot-starter-validation'
2523

24+
// Spring AI
25+
implementation platform("org.springframework.ai:spring-ai-bom:1.1.2")
26+
implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-repository-jdbc'
27+
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
28+
2629
// Stats
2730
implementation 'org.springframework.boot:spring-boot-starter-actuator'
2831
implementation 'io.micrometer:micrometer-registry-prometheus'

src/main/java/com/example/spring/app/llm/LLMAnswerDTO.java

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package com.example.spring.app.llm;
2+
3+
import org.springframework.ai.chat.client.ChatClient;
4+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
5+
import org.springframework.ai.chat.memory.ChatMemory;
6+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
7+
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
8+
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
9+
import org.springframework.ai.chat.memory.repository.jdbc.PostgresChatMemoryRepositoryDialect;
10+
import org.springframework.context.annotation.Bean;
11+
import org.springframework.context.annotation.Configuration;
12+
import org.springframework.jdbc.core.JdbcTemplate;
13+
14+
@Configuration
15+
public class LLMConfig {
16+
@Bean
17+
public ChatClient chatClient(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
18+
return chatClientBuilder
19+
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
20+
.defaultSystem(
21+
"You are a helpful assistant which is named Pierre. You are developed by the goat Mathieu." +
22+
"You always answer concisely and clearly. Don't be verbose." +
23+
"Do not translate into another language unless explicitly asked. " +
24+
"Very important: Always respond in Markdown." +
25+
"Very important: Use a Marseillais accent when speaking french."
26+
)
27+
.build();
28+
}
29+
30+
@Bean
31+
public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) {
32+
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
33+
.jdbcTemplate(jdbcTemplate)
34+
.dialect(new PostgresChatMemoryRepositoryDialect())
35+
.build();
36+
37+
return MessageWindowChatMemory.builder()
38+
.chatMemoryRepository(chatMemoryRepository)
39+
.maxMessages(10)
40+
.build();
41+
}
42+
}
Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,81 @@
11
package com.example.spring.app.llm;
22

3+
import com.example.spring.app.llm.dto.ChatStreamResponseDTO;
4+
import com.example.spring.app.llm.dto.ConversationDTO;
5+
import com.example.spring.app.llm.dto.MessageDTO;
6+
import com.example.spring.app.llm.springAiChatMemory.SpringAiChatMemoryService;
7+
import com.example.spring.app.llm.userConversation.UserConversationModel;
8+
import com.example.spring.app.llm.userConversation.UserConversationService;
39
import org.springframework.ai.chat.client.ChatClient;
4-
import org.springframework.ai.chat.model.ChatResponse;
10+
import org.springframework.ai.chat.memory.ChatMemory;
511
import org.springframework.http.MediaType;
612
import org.springframework.web.bind.annotation.*;
713
import reactor.core.publisher.Flux;
814

15+
import java.time.Instant;
16+
import java.util.List;
17+
18+
import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader;
19+
920
@CrossOrigin
1021
@RestController
11-
@RequestMapping("/v1")
22+
@RequestMapping("/v1/chat")
1223
public class LLMController {
1324

1425
private final ChatClient chatClient;
26+
private final UserConversationService userConversationService;
27+
private final SpringAiChatMemoryService springAiChatMemoryService;
1528

16-
public LLMController(ChatClient.Builder chatClientBuilder) {
17-
this.chatClient = chatClientBuilder.build();
29+
public LLMController(ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) {
30+
this.chatClient = chatClient;
31+
this.userConversationService = userConversationService;
32+
this.springAiChatMemoryService = springAiChatMemoryService;
1833
}
1934

20-
@GetMapping("/ask-ai")
21-
LLMAnswerDTO generation(String userInput) {
22-
LLMAnswerDTO llmAnswerDTO = new LLMAnswerDTO();
23-
String response = this.chatClient.prompt()
24-
.user(userInput)
25-
.call()
26-
.content();
35+
@PostMapping(value = "/conversation/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
36+
public Flux<ChatStreamResponseDTO> streamGeneration(@PathVariable String conversationId, @RequestBody LLMRequest request) {
37+
String userId = extractUserIdFromHeader();
2738

28-
llmAnswerDTO.setAnswer(response);
29-
return llmAnswerDTO;
30-
}
39+
UserConversationModel conversation =
40+
conversationId.equals("new")
41+
? userConversationService.createNewConversationForUser(userId)
42+
: userConversationService.getUserConversation(conversationId, userId);
43+
44+
if (conversation == null) {
45+
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
46+
}
3147

32-
@GetMapping(value = "/stream-ai", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
33-
public Flux<ChatResponse> streamGeneration(@RequestParam String userInput) {
3448
return chatClient.prompt()
35-
.user(userInput)
49+
.user(userSpec -> userSpec.text(request.userInput()))
50+
.advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversation.getConversationId()))
3651
.stream()
37-
.chatResponse();
52+
.chatResponse()
53+
.map(chatResponse -> new ChatStreamResponseDTO(conversation.getConversationId(), chatResponse, Instant.now().toEpochMilli()));
54+
}
55+
56+
57+
// TODO: Add pagination to this endpoint
58+
@GetMapping("/conversation/history/{conversationId}")
59+
public List<MessageDTO> getConversationHistory(@PathVariable String conversationId) {
60+
String userId = extractUserIdFromHeader();
61+
UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId);
62+
63+
if (conversation == null) {
64+
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
65+
}
66+
67+
return springAiChatMemoryService.findAllByConversationId(conversationId).stream()
68+
.map(chatMemoryModel -> new MessageDTO(chatMemoryModel.getContent(), chatMemoryModel.getType(), chatMemoryModel.getTimestamp()))
69+
.toList();
70+
}
71+
72+
@GetMapping("/conversation/all")
73+
public List<ConversationDTO> getAllUserConversations() {
74+
String userId = extractUserIdFromHeader();
75+
List<UserConversationModel> conversations = userConversationService.getAllConversationsForUser(userId);
76+
77+
return conversations.stream()
78+
.map(conv -> new ConversationDTO(conv.getConversationId()))
79+
.toList();
3880
}
3981
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.example.spring.app.llm;
2+
3+
public record LLMRequest(
4+
String userInput
5+
){}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package com.example.spring.app.llm.dto;
2+
3+
import org.springframework.ai.chat.model.ChatResponse;
4+
5+
public record ChatStreamResponseDTO(
6+
String conversationId,
7+
ChatResponse chatResponse,
8+
Long timestamp
9+
){}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package com.example.spring.app.llm.dto;
2+
3+
public record ConversationDTO(
4+
String conversationId
5+
//String title
6+
) {}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.example.spring.app.llm.dto;
2+
3+
import org.springframework.ai.chat.messages.MessageType;
4+
5+
import java.time.LocalDateTime;
6+
7+
public record MessageDTO(
8+
String message,
9+
MessageType messageType,
10+
LocalDateTime timestamp
11+
) {}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package com.example.spring.app.llm.springAiChatMemory;
2+
3+
import jakarta.persistence.*;
4+
import lombok.Getter;
5+
import lombok.Setter;
6+
import org.hibernate.validator.constraints.Length;
7+
import org.springframework.ai.chat.messages.MessageType;
8+
9+
import java.time.LocalDateTime;
10+
11+
@Setter
12+
@Getter
13+
@Entity
14+
@Table(name = "spring_ai_chat_memory")
15+
public class SpringAiChatMemoryModel {
16+
@Id
17+
@GeneratedValue(strategy = GenerationType.IDENTITY)
18+
private Integer id;
19+
20+
@Length(max = 36)
21+
private String conversationId;
22+
23+
@Enumerated(EnumType.STRING)
24+
@Column(length = 10, nullable = false)
25+
private MessageType type;
26+
27+
private String content;
28+
private LocalDateTime timestamp;
29+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package com.example.spring.app.llm.springAiChatMemory;
2+
3+
import org.springframework.data.jpa.repository.JpaRepository;
4+
5+
import java.util.List;
6+
7+
public interface SpringAiChatMemoryRepository extends JpaRepository<SpringAiChatMemoryModel, Integer> {
8+
List<SpringAiChatMemoryModel> findAllByConversationId(String conversationId);
9+
}

0 commit comments

Comments
 (0)