diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 9fe134ed6..60b2f8935 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -49,6 +49,7 @@ import ( "github.com/apache/answer/internal/repo/collection" "github.com/apache/answer/internal/repo/comment" "github.com/apache/answer/internal/repo/config" + "github.com/apache/answer/internal/repo/embedding" "github.com/apache/answer/internal/repo/export" "github.com/apache/answer/internal/repo/file_record" "github.com/apache/answer/internal/repo/limit" @@ -87,6 +88,7 @@ import ( config2 "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + embedding2 "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" export2 "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -247,7 +249,9 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, reasonService := reason2.NewReasonService(reasonRepo) reasonController := controller.NewReasonController(reasonService) themeController := controller_admin.NewThemeController() - siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService) + embeddingRepo := embedding.NewEmbeddingRepo(dataData) + embeddingService := embedding2.NewEmbeddingService(embeddingRepo, searchService, answerService, questionCommon, commentRepo, siteInfoCommonService) + siteInfoService := siteinfo.NewSiteInfoService(siteInfoRepo, siteInfoCommonService, emailService, tagCommonService, configService, questionCommon, fileRecordService, embeddingService) siteInfoController := controller_admin.NewSiteInfoController(siteInfoService) controllerSiteInfoController := controller.NewSiteInfoController(siteInfoCommonService) notificationCommon := notificationcommon.NewNotificationCommon(dataData, notificationRepo, userCommon, activityRepo, followRepo, objService, noticequeueService, userExternalLoginRepo, siteInfoCommonService) @@ -283,7 +287,7 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, apiKeyService := apikey.NewAPIKeyService(apiKeyRepo) adminAPIKeyController := controller_admin.NewAdminAPIKeyController(apiKeyService) featureToggleService := feature_toggle.NewFeatureToggleService(siteInfoRepo) - mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService) + mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService, embeddingService) aiConversationRepo := ai_conversation.NewAIConversationRepo(dataData) aiConversationService := ai_conversation2.NewAIConversationService(aiConversationRepo, userCommon) aiController := controller.NewAIController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, mcpController, aiConversationService, featureToggleService) diff --git a/docs/docs.go b/docs/docs.go index 57a23d432..8f65e05da 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -11712,6 +11712,24 @@ const docTemplate = `{ "type": "string", "maxLength": 256 }, + "embedding_crontab": { + "type": "string", + "maxLength": 100 + }, + "embedding_dimensions": { + "type": "integer" + }, + "embedding_level": { + "type": "string", + "enum": [ + "question", + "answer" + ] + }, + "embedding_model": { + "type": "string", + "maxLength": 100 + }, "model": { "type": "string", "maxLength": 100 @@ -11719,6 +11737,11 @@ const docTemplate = `{ "provider": { "type": "string", "maxLength": 50 + }, + "similarity_threshold": { + "type": "number", + "maximum": 1, + "minimum": 0 } } }, diff --git a/docs/swagger.json b/docs/swagger.json index dac2b38fd..71e802dfa 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -11685,6 +11685,24 @@ "type": "string", "maxLength": 256 }, + "embedding_crontab": { + "type": "string", + "maxLength": 100 + }, + "embedding_dimensions": { + "type": "integer" + }, + "embedding_level": { + "type": "string", + "enum": [ + "question", + "answer" + ] + }, + "embedding_model": { + "type": "string", + "maxLength": 100 + }, "model": { "type": "string", "maxLength": 100 @@ -11692,6 +11710,11 @@ "provider": { "type": "string", "maxLength": 50 + }, + "similarity_threshold": { + "type": "number", + "maximum": 1, + "minimum": 0 } } }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 7a7adb681..5dc68d6c4 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2250,12 +2250,29 @@ definitions: api_key: maxLength: 256 type: string + embedding_crontab: + maxLength: 100 + type: string + embedding_dimensions: + type: integer + embedding_level: + enum: + - question + - answer + type: string + embedding_model: + maxLength: 100 + type: string model: maxLength: 100 type: string provider: maxLength: 50 type: string + similarity_threshold: + maximum: 1 + minimum: 0 + type: number type: object schema.SiteAIReq: properties: diff --git a/i18n/en_US.yaml b/i18n/en_US.yaml index 9a0d198b3..ec385e763 100644 --- a/i18n/en_US.yaml +++ b/i18n/en_US.yaml @@ -2355,6 +2355,24 @@ ui: label: Model msg: Model is required add_success: AI settings updated successfully. + embedding_settings: Embedding Settings + embedding_model: + label: Embedding model + text: "The model used to generate vector embeddings for semantic search (e.g. text-embedding-3-small)." + embedding_dimensions: + label: Embedding dimensions + text: "The number of dimensions for the embedding vectors (e.g. 1536 for text-embedding-3-small)." + embedding_level: + label: Embedding level + text: "Choose whether to create embeddings at the question level (question + all answers + comments) or answer level (each answer separately)." + question: Question level + answer: Answer level + embedding_crontab: + label: Embedding schedule + text: "Cron expression for periodic embedding calculation (e.g. '0 */6 * * *' for every 6 hours). Leave empty to disable automatic indexing." + similarity_threshold: + label: Similarity threshold + text: "Minimum cosine similarity score (0-1) for semantic search results. Only results with a score above this threshold will be returned. Default is 0 (no filtering)." conversations: topic: Topic helpful: Helpful diff --git a/i18n/zh_CN.yaml b/i18n/zh_CN.yaml index f16ed9fad..191619dd1 100644 --- a/i18n/zh_CN.yaml +++ b/i18n/zh_CN.yaml @@ -2319,6 +2319,24 @@ ui: label: 模型 msg: 模型是必需的 add_success: AI 设置更新成功。 + embedding_settings: Embedding 设置 + embedding_model: + label: Embedding 模型 + text: "用于生成语义搜索向量 Embedding 的模型(例如 text-embedding-3-small)。" + embedding_dimensions: + label: Embedding 维度 + text: "Embedding 向量的维度数(例如 text-embedding-3-small 为 1536)。" + embedding_level: + label: Embedding 级别 + text: "选择在问题级别(问题 + 所有回答 + 评论)还是回答级别(每个回答单独)创建 Embedding。" + question: 问题级别 + answer: 回答级别 + embedding_crontab: + label: Embedding 计划 + text: "定期计算 Embedding 的 Cron 表达式(例如 '0 */6 * * *' 表示每 6 小时)。留空则禁用自动索引。" + similarity_threshold: + label: 相似度阈值 + text: "语义搜索结果的最低余弦相似度分数(0-1)。只有分数高于此阈值的结果才会被返回。默认值为 0(不过滤)。" conversations: topic: 主题 helpful: 有帮助 diff --git a/internal/base/constant/ai_config.go b/internal/base/constant/ai_config.go index aa733bbaf..a25e47a45 100644 --- a/internal/base/constant/ai_config.go +++ b/internal/base/constant/ai_config.go @@ -33,6 +33,7 @@ const ( - get_tags: 搜索标签信息 - get_tag_detail: 获取特定标签的详细信息 - get_user: 搜索用户信息 +- semantic_search: 通过语义相似度搜索问题和答案。当用户的问题与现有内容概念相关但可能不匹配确切关键词时使用此工具。当 get_questions 关键词搜索返回较差结果时,请使用 semantic_search。 请根据用户的问题智能地使用这些工具来提供准确的答案。如果需要查询系统信息,请先使用相应的工具获取数据。` DefaultAIPromptConfigEnUS = `You are an intelligent assistant that can help users query information in the system. User question: %s @@ -44,6 +45,7 @@ You can use the following tools to query system information: - get_tags: Search for tag information - get_tag_detail: Get detailed information about a specific tag - get_user: Search for user information +- semantic_search: Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. When get_questions keyword search returns poor results, use semantic_search instead. Please intelligently use these tools based on the user's question to provide accurate answers. If you need to query system information, please use the appropriate tools to get the data first.` ) diff --git a/internal/controller/ai_controller.go b/internal/controller/ai_controller.go index e020ed30e..125cdab22 100644 --- a/internal/controller/ai_controller.go +++ b/internal/controller/ai_controller.go @@ -446,6 +446,7 @@ func (c *AIController) handleAIConversation(ctx *gin.Context, w http.ResponseWri toolCalls, newMessages, finished, aiResponse := c.processAIStream(ctx, w, id, conversationCtx.Model, client, aiReq, messages) messages = newMessages + log.Debugf("Round %d: toolCalls=%v", round+1, toolCalls) if aiResponse != "" { conversationCtx.Messages = append(conversationCtx.Messages, &ai_conversation.ConversationMessage{ Role: "assistant", @@ -497,6 +498,10 @@ func (c *AIController) processAIStream( break } + if len(response.Choices) == 0 { + continue + } + choice := response.Choices[0] if len(choice.Delta.ToolCalls) > 0 { @@ -735,6 +740,8 @@ func (c *AIController) callMCPTool(ctx context.Context, toolName string, argumen result, err = c.mcpController.MCPTagDetailsHandler()(ctx, request) case "get_user": result, err = c.mcpController.MCPUserDetailsHandler()(ctx, request) + case "semantic_search": + result, err = c.mcpController.MCPSemanticSearchHandler()(ctx, request) default: return "", fmt.Errorf("unknown tool: %s", toolName) } diff --git a/internal/controller/mcp_controller.go b/internal/controller/mcp_controller.go index d52f57979..fecdbef60 100644 --- a/internal/controller/mcp_controller.go +++ b/internal/controller/mcp_controller.go @@ -31,6 +31,7 @@ import ( answercommon "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/content" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/feature_toggle" questioncommon "github.com/apache/answer/internal/service/question_common" "github.com/apache/answer/internal/service/siteinfo_common" @@ -49,6 +50,7 @@ type MCPController struct { userCommon *usercommon.UserCommon answerRepo answercommon.AnswerRepo featureToggleSvc *feature_toggle.FeatureToggleService + embeddingService *embedding.EmbeddingService } // NewMCPController new site info controller. @@ -61,6 +63,7 @@ func NewMCPController( userCommon *usercommon.UserCommon, answerRepo answercommon.AnswerRepo, featureToggleSvc *feature_toggle.FeatureToggleService, + embeddingService *embedding.EmbeddingService, ) *MCPController { return &MCPController{ searchService: searchService, @@ -71,6 +74,7 @@ func NewMCPController( userCommon: userCommon, answerRepo: answerRepo, featureToggleSvc: featureToggleSvc, + embeddingService: embeddingService, } } @@ -349,3 +353,131 @@ func (c *MCPController) MCPUserDetailsHandler() func(ctx context.Context, reques return mcp.NewToolResultText(string(res)), nil } } + +func (c *MCPController) MCPSemanticSearchHandler() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := c.ensureMCPEnabled(ctx); err != nil { + return nil, err + } + cond := schema.NewMCPSemanticSearchCond(request) + if len(cond.Query) == 0 { + return mcp.NewToolResultText("Query is required for semantic search."), nil + } + + siteGeneral, err := c.siteInfoService.GetSiteGeneral(ctx) + if err != nil { + log.Errorf("get site general info failed: %v", err) + return nil, err + } + + results, err := c.embeddingService.SearchSimilar(ctx, cond.Query, cond.TopK) + if err != nil { + log.Errorf("semantic search failed: %v", err) + return mcp.NewToolResultText("Semantic search is not available. Embedding may not be configured."), nil + } + if len(results) == 0 { + return mcp.NewToolResultText("No semantically similar content found."), nil + } + + resp := make([]*schema.MCPSemanticSearchResp, 0, len(results)) + for _, r := range results { + var meta entity.EmbeddingMetadata + _ = json.Unmarshal([]byte(r.Metadata), &meta) + + item := &schema.MCPSemanticSearchResp{ + ObjectID: r.ObjectID, + ObjectType: r.ObjectType, + Score: r.Score, + } + + // Compose link from metadata + if r.ObjectType == "answer" && meta.AnswerID != "" { + item.Link = fmt.Sprintf("%s/questions/%s/%s", siteGeneral.SiteUrl, meta.QuestionID, meta.AnswerID) + } else { + item.Link = fmt.Sprintf("%s/questions/%s", siteGeneral.SiteUrl, meta.QuestionID) + } + + // Query content from DB using IDs stored in metadata + if r.ObjectType == "question" { + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr != nil { + log.Warnf("get question %s for semantic search failed: %v", meta.QuestionID, qErr) + } else { + item.Title = question.Title + item.Content = question.Content + } + + // Fetch answers by ID from metadata + for _, a := range meta.Answers { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, a.AnswerID) + if aErr != nil || !exist { + continue + } + answerItem := &schema.MCPSemanticSearchAnswer{ + AnswerID: a.AnswerID, + Content: answerEntity.OriginalText, + } + // Fetch comments on this answer from DB + for _, ac := range a.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + answerItem.Comments = append(answerItem.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + item.Answers = append(item.Answers, answerItem) + } + + // Fetch question comments from DB + for _, qc := range meta.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, qc.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, &schema.MCPSemanticSearchComment{ + CommentID: qc.CommentID, + Content: cmt.OriginalText, + }) + } + } + } else if r.ObjectType == "answer" { + // Fetch question title for context + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr == nil { + item.Title = question.Title + } + + // Fetch answer content from DB + if meta.AnswerID != "" { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } else if len(meta.Answers) > 0 { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.Answers[0].AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } + + // Fetch answer comments from DB + if len(meta.Answers) > 0 { + for _, ac := range meta.Answers[0].Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + } + } + + resp = append(resp, item) + } + + data, _ := json.Marshal(resp) + return mcp.NewToolResultText(string(data)), nil + } +} diff --git a/internal/entity/embedding.go b/internal/entity/embedding.go new file mode 100644 index 000000000..3ea500d92 --- /dev/null +++ b/internal/entity/embedding.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package entity + +import "time" + +// Embedding stores vector embeddings for questions or answers. +type Embedding struct { + ID int `xorm:"not null pk autoincr INT(11) id"` + CreatedAt time.Time `xorm:"created not null default CURRENT_TIMESTAMP TIMESTAMP created_at"` + UpdatedAt time.Time `xorm:"updated not null default CURRENT_TIMESTAMP TIMESTAMP updated_at"` + ObjectID string `xorm:"not null BIGINT(20) INDEX object_id unique(object_embedding)"` + ObjectType string `xorm:"not null default '' VARCHAR(20) object_type unique(object_embedding)"` + ContentHash string `xorm:"not null default '' VARCHAR(64) content_hash"` + Metadata string `xorm:"not null MEDIUMTEXT metadata"` + Embedding string `xorm:"not null MEDIUMTEXT embedding"` + Dimensions int `xorm:"not null default 0 INT(11) dimensions"` +} + +// TableName returns the table name +func (Embedding) TableName() string { + return "embedding" +} + +// EmbeddingMetadata holds IDs for URI composition and content retrieval at query time. +type EmbeddingMetadata struct { + QuestionID string `json:"question_id"` + AnswerID string `json:"answer_id,omitempty"` + Answers []EmbeddingMetadataAnswer `json:"answers,omitempty"` + Comments []EmbeddingMetadataComment `json:"comments,omitempty"` +} + +// EmbeddingMetadataAnswer stores answer ID and comment IDs in metadata. +type EmbeddingMetadataAnswer struct { + AnswerID string `json:"answer_id"` + Comments []EmbeddingMetadataComment `json:"comments,omitempty"` +} + +// EmbeddingMetadataComment stores comment ID in metadata. +type EmbeddingMetadataComment struct { + CommentID string `json:"comment_id"` +} diff --git a/internal/migrations/init_data.go b/internal/migrations/init_data.go index 5af41bbfc..e65c9f11b 100644 --- a/internal/migrations/init_data.go +++ b/internal/migrations/init_data.go @@ -79,6 +79,7 @@ var ( &entity.APIKey{}, &entity.AIConversation{}, &entity.AIConversationRecord{}, + &entity.Embedding{}, } roles = []*entity.Role{ diff --git a/internal/migrations/migrations.go b/internal/migrations/migrations.go index 682a7b207..e6722cc08 100644 --- a/internal/migrations/migrations.go +++ b/internal/migrations/migrations.go @@ -108,6 +108,7 @@ var migrations = []Migration{ NewMigration("v1.8.0", "change admin menu", updateAdminMenuSettings, true), NewMigration("v1.8.1", "ai feat", aiFeat, true), NewMigration("v2.0.1", "change avatar type to text", updateAvatarType, false), + NewMigration("v2.0.2", "add embedding table", addEmbeddingTable, false), } func GetMigrations() []Migration { diff --git a/internal/migrations/v32.go b/internal/migrations/v32.go index fc6614b11..438097348 100644 --- a/internal/migrations/v32.go +++ b/internal/migrations/v32.go @@ -24,6 +24,7 @@ import ( "fmt" "github.com/apache/answer/internal/entity" + "github.com/segmentfault/pacman/log" "xorm.io/xorm" ) @@ -35,3 +36,11 @@ func updateAvatarType(ctx context.Context, x *xorm.Engine) error { } return nil } + +func addEmbeddingTable(ctx context.Context, x *xorm.Engine) error { + if err := x.Context(ctx).Sync(new(entity.Embedding)); err != nil { + return fmt.Errorf("sync embedding table failed: %w", err) + } + log.Info("Embedding table migration completed successfully") + return nil +} diff --git a/internal/repo/embedding/embedding_repo.go b/internal/repo/embedding/embedding_repo.go new file mode 100644 index 000000000..67bde1d5d --- /dev/null +++ b/internal/repo/embedding/embedding_repo.go @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package embedding + +import ( + "context" + "encoding/json" + "math" + "sort" + + "github.com/apache/answer/internal/base/data" + "github.com/apache/answer/internal/entity" + "github.com/segmentfault/pacman/log" + "xorm.io/builder" +) + +// EmbeddingRepo defines the interface for embedding data access. +type EmbeddingRepo interface { + Upsert(ctx context.Context, emb *entity.Embedding) error + GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) + GetAll(ctx context.Context) ([]*entity.Embedding, error) + SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) + DeleteByObjectID(ctx context.Context, objectID, objectType string) error + Count(ctx context.Context) (int64, error) +} + +// SimilarResult holds a similarity search result. +type SimilarResult struct { + ObjectID string `json:"object_id"` + ObjectType string `json:"object_type"` + Metadata string `json:"metadata"` + Score float64 `json:"score"` +} + +type embeddingRepo struct { + data *data.Data +} + +// NewEmbeddingRepo creates a new EmbeddingRepo. +func NewEmbeddingRepo(data *data.Data) EmbeddingRepo { + return &embeddingRepo{data: data} +} + +// Upsert inserts or updates an embedding by (object_id, object_type). +func (r *embeddingRepo) Upsert(ctx context.Context, emb *entity.Embedding) error { + existing := &entity.Embedding{} + exist, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": emb.ObjectID, "object_type": emb.ObjectType}). + Get(existing) + if err != nil { + log.Errorf("check embedding existence failed: %v", err) + return err + } + + if exist { + emb.ID = existing.ID + _, err = r.data.DB.Context(ctx).ID(existing.ID). + Cols("content_hash", "metadata", "embedding", "dimensions", "updated_at"). + Update(emb) + if err != nil { + log.Errorf("update embedding failed: %v", err) + return err + } + return nil + } + + _, err = r.data.DB.Context(ctx).Insert(emb) + if err != nil { + log.Errorf("insert embedding failed: %v", err) + return err + } + return nil +} + +// GetByObjectID returns an embedding by object ID and type. +func (r *embeddingRepo) GetByObjectID(ctx context.Context, objectID, objectType string) (*entity.Embedding, bool, error) { + emb := &entity.Embedding{} + exist, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": objectID, "object_type": objectType}). + Get(emb) + if err != nil { + log.Errorf("get embedding failed: %v", err) + return nil, false, err + } + return emb, exist, nil +} + +// GetAll returns all embeddings. +func (r *embeddingRepo) GetAll(ctx context.Context) ([]*entity.Embedding, error) { + var list []*entity.Embedding + err := r.data.DB.Context(ctx).Find(&list) + if err != nil { + log.Errorf("get all embeddings failed: %v", err) + return nil, err + } + return list, nil +} + +// SearchSimilar performs brute-force cosine similarity search in Go. +func (r *embeddingRepo) SearchSimilar(ctx context.Context, queryVector []float32, topK int) ([]SimilarResult, error) { + allEmbeddings, err := r.GetAll(ctx) + if err != nil { + return nil, err + } + + type scored struct { + emb *entity.Embedding + score float64 + } + results := make([]scored, 0, len(allEmbeddings)) + + for _, emb := range allEmbeddings { + var vec []float32 + if err := json.Unmarshal([]byte(emb.Embedding), &vec); err != nil { + log.Warnf("skip embedding id=%d, unmarshal failed: %v", emb.ID, err) + continue + } + score := cosineSimilarity(queryVector, vec) + results = append(results, scored{emb: emb, score: score}) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].score > results[j].score + }) + + if topK > len(results) { + topK = len(results) + } + + out := make([]SimilarResult, 0, topK) + for i := 0; i < topK; i++ { + out = append(out, SimilarResult{ + ObjectID: results[i].emb.ObjectID, + ObjectType: results[i].emb.ObjectType, + Metadata: results[i].emb.Metadata, + Score: results[i].score, + }) + } + return out, nil +} + +// DeleteByObjectID deletes an embedding by object ID and type. +func (r *embeddingRepo) DeleteByObjectID(ctx context.Context, objectID, objectType string) error { + _, err := r.data.DB.Context(ctx). + Where(builder.Eq{"object_id": objectID, "object_type": objectType}). + Delete(&entity.Embedding{}) + if err != nil { + log.Errorf("delete embedding failed: %v", err) + return err + } + return nil +} + +// Count returns the total number of embeddings. +func (r *embeddingRepo) Count(ctx context.Context) (int64, error) { + count, err := r.data.DB.Context(ctx).Count(&entity.Embedding{}) + if err != nil { + log.Errorf("count embeddings failed: %v", err) + return 0, err + } + return count, nil +} + +// cosineSimilarity computes cosine similarity between two vectors. +func cosineSimilarity(a, b []float32) float64 { + if len(a) != len(b) || len(a) == 0 { + return 0 + } + var dotProduct, normA, normB float64 + for i := range a { + dotProduct += float64(a[i]) * float64(b[i]) + normA += float64(a[i]) * float64(a[i]) + normB += float64(b[i]) * float64(b[i]) + } + denom := math.Sqrt(normA) * math.Sqrt(normB) + if denom == 0 { + return 0 + } + return dotProduct / denom +} diff --git a/internal/repo/provider.go b/internal/repo/provider.go index 510a94aaa..2a9717d00 100644 --- a/internal/repo/provider.go +++ b/internal/repo/provider.go @@ -34,6 +34,7 @@ import ( "github.com/apache/answer/internal/repo/collection" "github.com/apache/answer/internal/repo/comment" "github.com/apache/answer/internal/repo/config" + "github.com/apache/answer/internal/repo/embedding" "github.com/apache/answer/internal/repo/export" "github.com/apache/answer/internal/repo/file_record" "github.com/apache/answer/internal/repo/limit" @@ -113,4 +114,5 @@ var ProviderSetRepo = wire.NewSet( file_record.NewFileRecordRepo, api_key.NewAPIKeyRepo, ai_conversation.NewAIConversationRepo, + embedding.NewEmbeddingRepo, ) diff --git a/internal/schema/mcp_schema.go b/internal/schema/mcp_schema.go index bead21c9d..9afee72ec 100644 --- a/internal/schema/mcp_schema.go +++ b/internal/schema/mcp_schema.go @@ -27,15 +27,17 @@ import ( ) const ( - MCPSearchCondKeyword = "keyword" - MCPSearchCondUsername = "username" - MCPSearchCondScore = "score" - MCPSearchCondTag = "tag" - MCPSearchCondPage = "page" - MCPSearchCondPageSize = "page_size" - MCPSearchCondTagName = "tag_name" - MCPSearchCondQuestionID = "question_id" - MCPSearchCondObjectID = "object_id" + MCPSearchCondKeyword = "keyword" + MCPSearchCondUsername = "username" + MCPSearchCondScore = "score" + MCPSearchCondTag = "tag" + MCPSearchCondPage = "page" + MCPSearchCondPageSize = "page_size" + MCPSearchCondTagName = "tag_name" + MCPSearchCondQuestionID = "question_id" + MCPSearchCondObjectID = "object_id" + MCPSearchCondSemanticQuery = "query" + MCPSearchCondTopK = "top_k" ) type MCPSearchCond struct { @@ -98,6 +100,48 @@ type MCPSearchCommentInfoResp struct { Link string `json:"link"` } +// MCPSemanticSearchCond is the condition for semantic search. +type MCPSemanticSearchCond struct { + Query string `json:"query"` + TopK int `json:"top_k"` +} + +// MCPSemanticSearchResp is a single semantic search result. +type MCPSemanticSearchResp struct { + ObjectID string `json:"object_id"` + ObjectType string `json:"object_type"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` + Link string `json:"link"` + Answers []*MCPSemanticSearchAnswer `json:"answers,omitempty"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchAnswer is an answer in a semantic search result. +type MCPSemanticSearchAnswer struct { + AnswerID string `json:"answer_id"` + Content string `json:"content"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchComment is a comment in a semantic search result. +type MCPSemanticSearchComment struct { + CommentID string `json:"comment_id"` + Content string `json:"content"` +} + +func NewMCPSemanticSearchCond(request mcp.CallToolRequest) *MCPSemanticSearchCond { + cond := &MCPSemanticSearchCond{TopK: 5} + if query, ok := getRequestValue(request, MCPSearchCondSemanticQuery); ok { + cond.Query = query + } + if topK, ok := getRequestNumber(request, MCPSearchCondTopK); ok && topK > 0 { + cond.TopK = topK + } + return cond +} + func NewMCPSearchCond(request mcp.CallToolRequest) *MCPSearchCond { cond := &MCPSearchCond{} if keyword, ok := getRequestValue(request, MCPSearchCondKeyword); ok { diff --git a/internal/schema/mcp_tools/mcp_tools.go b/internal/schema/mcp_tools/mcp_tools.go index 949a738c7..3ae6b3bea 100644 --- a/internal/schema/mcp_tools/mcp_tools.go +++ b/internal/schema/mcp_tools/mcp_tools.go @@ -32,6 +32,7 @@ var ( NewTagsTool(), NewTagDetailTool(), NewUserTool(), + NewSemanticSearchTool(), } ) @@ -103,3 +104,17 @@ func NewUserTool() mcp.Tool { ) return listFilesTool } + +func NewSemanticSearchTool() mcp.Tool { + tool := mcp.NewTool("semantic_search", + mcp.WithDescription("Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. Returns the most semantically similar content."), + mcp.WithString(schema.MCPSearchCondSemanticQuery, + mcp.Required(), + mcp.Description("The search query text to find semantically similar questions and answers"), + ), + mcp.WithNumber(schema.MCPSearchCondTopK, + mcp.Description("Maximum number of results to return (default 5)"), + ), + ) + return tool +} diff --git a/internal/schema/siteinfo_schema.go b/internal/schema/siteinfo_schema.go index bdf2308d3..84bc54e30 100644 --- a/internal/schema/siteinfo_schema.go +++ b/internal/schema/siteinfo_schema.go @@ -281,10 +281,15 @@ func (s *SiteAIResp) GetProvider() *SiteAIProvider { } type SiteAIProvider struct { - Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` - APIHost string `validate:"omitempty,lte=512" form:"api_host" json:"api_host"` - APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` - Model string `validate:"omitempty,lte=100" form:"model" json:"model"` + Provider string `validate:"omitempty,lte=50" form:"provider" json:"provider"` + APIHost string `validate:"omitempty,lte=512" form:"api_host" json:"api_host"` + APIKey string `validate:"omitempty,lte=256" form:"api_key" json:"api_key"` + Model string `validate:"omitempty,lte=100" form:"model" json:"model"` + EmbeddingModel string `validate:"omitempty,lte=100" form:"embedding_model" json:"embedding_model"` + EmbeddingDimensions int `validate:"omitempty" form:"embedding_dimensions" json:"embedding_dimensions"` + EmbeddingLevel string `validate:"omitempty,oneof=question answer" form:"embedding_level" json:"embedding_level"` + EmbeddingCrontab string `validate:"omitempty,lte=100" form:"embedding_crontab" json:"embedding_crontab"` + SimilarityThreshold float64 `validate:"omitempty,gte=0,lte=1" form:"similarity_threshold" json:"similarity_threshold"` } // SiteAIResp AI configuration response diff --git a/internal/service/embedding/embedding_service.go b/internal/service/embedding/embedding_service.go new file mode 100644 index 000000000..72e008abb --- /dev/null +++ b/internal/service/embedding/embedding_service.go @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package embedding + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/apache/answer/internal/base/pager" + "github.com/apache/answer/internal/entity" + embeddingRepo "github.com/apache/answer/internal/repo/embedding" + "github.com/apache/answer/internal/schema" + "github.com/apache/answer/internal/service/comment" + "github.com/apache/answer/internal/service/content" + questioncommon "github.com/apache/answer/internal/service/question_common" + "github.com/apache/answer/internal/service/siteinfo_common" + "github.com/robfig/cron/v3" + "github.com/sashabaranov/go-openai" + "github.com/segmentfault/pacman/log" +) + +const ( + EmbeddingLevelQuestion = "question" + EmbeddingLevelAnswer = "answer" +) + +// EmbeddingService handles embedding generation, text aggregation, and indexing. +type EmbeddingService struct { + embeddingRepo embeddingRepo.EmbeddingRepo + searchService *content.SearchService + answerService *content.AnswerService + questionCommon *questioncommon.QuestionCommon + commentRepo comment.CommentRepo + siteInfoService siteinfo_common.SiteInfoCommonService + + mu sync.Mutex + cronJob *cron.Cron + cronSpec string +} + +// NewEmbeddingService creates a new EmbeddingService. +func NewEmbeddingService( + embeddingRepo embeddingRepo.EmbeddingRepo, + searchService *content.SearchService, + answerService *content.AnswerService, + questionCommon *questioncommon.QuestionCommon, + commentRepo comment.CommentRepo, + siteInfoService siteinfo_common.SiteInfoCommonService, +) *EmbeddingService { + return &EmbeddingService{ + embeddingRepo: embeddingRepo, + searchService: searchService, + answerService: answerService, + questionCommon: questionCommon, + commentRepo: commentRepo, + siteInfoService: siteInfoService, + } +} + +// getAIConfig returns the current AI configuration. +func (s *EmbeddingService) getAIConfig(ctx context.Context) (*schema.SiteAIResp, *schema.SiteAIProvider, error) { + aiConfig, err := s.siteInfoService.GetSiteAI(ctx) + if err != nil { + return nil, nil, fmt.Errorf("get AI config failed: %w", err) + } + if !aiConfig.Enabled { + return nil, nil, fmt.Errorf("AI feature is disabled") + } + provider := aiConfig.GetProvider() + if provider.EmbeddingModel == "" { + return nil, nil, fmt.Errorf("embedding model not configured") + } + return aiConfig, provider, nil +} + +// createEmbeddingClient creates an OpenAI-compatible client for embedding requests. +func (s *EmbeddingService) createEmbeddingClient(provider *schema.SiteAIProvider) *openai.Client { + config := openai.DefaultConfig(provider.APIKey) + config.BaseURL = provider.APIHost + if !strings.HasSuffix(config.BaseURL, "/v1") { + config.BaseURL += "/v1" + } + return openai.NewClientWithConfig(config) +} + +// GenerateEmbedding generates an embedding vector for the given text. +func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, text string) ([]float32, error) { + _, provider, err := s.getAIConfig(ctx) + if err != nil { + return nil, err + } + + client := s.createEmbeddingClient(provider) + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ + Input: []string{text}, + Model: openai.EmbeddingModel(provider.EmbeddingModel), + }) + if err != nil { + return nil, fmt.Errorf("create embeddings failed: %w", err) + } + if len(resp.Data) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + return resp.Data[0].Embedding, nil +} + +// ComputeContentHash computes SHA256 of the text. +func ComputeContentHash(text string) string { + h := sha256.Sum256([]byte(text)) + return fmt.Sprintf("%x", h) +} + +// BuildTextForQuestion aggregates question title + body + all answers + comments into one text. +// Uses SearchService and QuestionCommon to respect the plugin architecture. +func (s *EmbeddingService) BuildTextForQuestion(ctx context.Context, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { + // Get question detail via service layer + question, err := s.questionCommon.Info(ctx, questionID, "") + if err != nil { + return "", nil, fmt.Errorf("get question info failed: %w", err) + } + + meta = &entity.EmbeddingMetadata{ + QuestionID: questionID, + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s\n%s", question.Title, question.Content)) + + // Get answers via AnswerService + answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ + QuestionID: questionID, + Page: 1, + PageSize: 50, + }) + if err != nil { + log.Warnf("get answers for question %s failed: %v", questionID, err) + } else { + for _, a := range answerInfoList { + parts = append(parts, fmt.Sprintf("Answer: %s", a.Content)) + answerMeta := entity.EmbeddingMetadataAnswer{ + AnswerID: a.ID, + } + + // Get comments on this answer + answerComments, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: a.ID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment on answer: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + } + } + + // Get comments on the question + commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: questionID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for question %s failed: %v", questionID, err) + } else { + for _, c := range commentList { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + + return strings.Join(parts, "\n\n"), meta, nil +} + +// BuildTextForAnswer aggregates answer body + parent question title + answer comments into one text. +func (s *EmbeddingService) BuildTextForAnswer(ctx context.Context, answerID, questionID string) (text string, meta *entity.EmbeddingMetadata, err error) { + // Get parent question title + question, err := s.questionCommon.Info(ctx, questionID, "") + if err != nil { + return "", nil, fmt.Errorf("get question info for answer failed: %w", err) + } + + meta = &entity.EmbeddingMetadata{ + QuestionID: questionID, + AnswerID: answerID, + } + + // Get the specific answer's content via AnswerService + answerInfo, err := s.answerService.GetDetail(ctx, answerID) + if err != nil { + return "", nil, fmt.Errorf("get answer failed: %w", err) + } + + var answerText string + if answerInfo != nil { + answerText = answerInfo.Content + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s", question.Title)) + if answerText != "" { + parts = append(parts, fmt.Sprintf("Answer: %s", answerText)) + meta.Answers = append(meta.Answers, entity.EmbeddingMetadataAnswer{ + AnswerID: answerID, + }) + } + + // Get comments on the answer + commentList, _, err := s.commentRepo.GetCommentPage(ctx, &comment.CommentQuery{ + PageCond: pager.PageCond{Page: 1, PageSize: 50}, + ObjectID: answerID, + QueryCond: "newest", + }) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", answerID, err) + } else { + for _, c := range commentList { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + if len(meta.Answers) > 0 { + meta.Answers[0].Comments = append(meta.Answers[0].Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } else { + meta.Comments = append(meta.Comments, entity.EmbeddingMetadataComment{ + CommentID: c.ID, + }) + } + } + } + + return strings.Join(parts, "\n\n"), meta, nil +} + +// IndexQuestion indexes a single question embedding. +func (s *EmbeddingService) IndexQuestion(ctx context.Context, questionID string) error { + text, meta, err := s.BuildTextForQuestion(ctx, questionID) + if err != nil { + return err + } + + contentHash := ComputeContentHash(text) + + // Check staleness + existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, questionID, EmbeddingLevelQuestion) + if exist && existing.ContentHash == contentHash { + return nil // already up to date + } + + vec, err := s.GenerateEmbedding(ctx, text) + if err != nil { + return fmt.Errorf("generate embedding for question %s failed: %w", questionID, err) + } + + metaJSON, _ := json.Marshal(meta) + vecJSON, _ := json.Marshal(vec) + + return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ + ObjectID: questionID, + ObjectType: EmbeddingLevelQuestion, + ContentHash: contentHash, + Metadata: string(metaJSON), + Embedding: string(vecJSON), + Dimensions: len(vec), + }) +} + +// IndexAnswer indexes a single answer embedding. +func (s *EmbeddingService) IndexAnswer(ctx context.Context, answerID, questionID string) error { + text, meta, err := s.BuildTextForAnswer(ctx, answerID, questionID) + if err != nil { + return err + } + + contentHash := ComputeContentHash(text) + + existing, exist, _ := s.embeddingRepo.GetByObjectID(ctx, answerID, EmbeddingLevelAnswer) + if exist && existing.ContentHash == contentHash { + return nil + } + + vec, err := s.GenerateEmbedding(ctx, text) + if err != nil { + return fmt.Errorf("generate embedding for answer %s failed: %w", answerID, err) + } + + metaJSON, _ := json.Marshal(meta) + vecJSON, _ := json.Marshal(vec) + + return s.embeddingRepo.Upsert(ctx, &entity.Embedding{ + ObjectID: answerID, + ObjectType: EmbeddingLevelAnswer, + ContentHash: contentHash, + Metadata: string(metaJSON), + Embedding: string(vecJSON), + Dimensions: len(vec), + }) +} + +// SearchSimilar performs semantic search and returns top-K similar results. +// Results below the configured similarity threshold are filtered out. +func (s *EmbeddingService) SearchSimilar(ctx context.Context, query string, topK int) ([]embeddingRepo.SimilarResult, error) { + vec, err := s.GenerateEmbedding(ctx, query) + if err != nil { + return nil, fmt.Errorf("generate query embedding failed: %w", err) + } + results, err := s.embeddingRepo.SearchSimilar(ctx, vec, topK) + if err != nil { + return nil, err + } + + for _, r := range results { + log.Debugf("semantic search result: object_id=%s object_type=%s score=%.6f", r.ObjectID, r.ObjectType, r.Score) + } + + // Apply similarity threshold from config (default 0 means no filtering) + _, provider, cfgErr := s.getAIConfig(ctx) + if cfgErr == nil && provider.SimilarityThreshold > 0 { + filtered := make([]embeddingRepo.SimilarResult, 0, len(results)) + for _, r := range results { + if r.Score >= provider.SimilarityThreshold { + filtered = append(filtered, r) + } + } + log.Debugf("semantic search: %d/%d results passed threshold %.4f", len(filtered), len(results), provider.SimilarityThreshold) + return filtered, nil + } + + return results, nil +} + +// GetEmbeddingCount returns the total number of stored embeddings. +func (s *EmbeddingService) GetEmbeddingCount(ctx context.Context) (int64, error) { + return s.embeddingRepo.Count(ctx) +} + +// RemoveEmbedding removes an embedding by object ID and type. +func (s *EmbeddingService) RemoveEmbedding(ctx context.Context, objectID, objectType string) error { + return s.embeddingRepo.DeleteByObjectID(ctx, objectID, objectType) +} + +// IndexAll indexes all questions (and optionally answers) based on the configured embedding level. +func (s *EmbeddingService) IndexAll(ctx context.Context) error { + _, provider, err := s.getAIConfig(ctx) + if err != nil { + log.Warnf("embedding indexer: %v", err) + return err + } + + level := provider.EmbeddingLevel + if level == "" { + level = EmbeddingLevelQuestion + } + + log.Debugf("Starting embedding indexer at level: %s", level) + + page := 1 + totalIndexed := 0 + for { + searchResp, err := s.searchService.Search(ctx, &schema.SearchDTO{ + Query: "is:question", + Page: page, + Size: 50, + Order: "newest", + }) + if err != nil { + return fmt.Errorf("search questions for indexing failed: %w", err) + } + if searchResp == nil || len(searchResp.SearchResults) == 0 { + break + } + + for _, result := range searchResp.SearchResults { + if result.Object == nil { + continue + } + qID := result.Object.QuestionID + if level == EmbeddingLevelQuestion { + if err := s.IndexQuestion(ctx, qID); err != nil { + log.Warnf("index question %s failed: %v", qID, err) + continue + } + totalIndexed++ + } else if level == EmbeddingLevelAnswer { + // Index each answer for this question via AnswerService + answerInfoList, _, err := s.answerService.SearchList(ctx, &schema.AnswerListReq{ + QuestionID: qID, + Page: 1, + PageSize: 50, + }) + if err != nil { + log.Warnf("get answers for question %s failed: %v", qID, err) + continue + } + for _, a := range answerInfoList { + if err := s.IndexAnswer(ctx, a.ID, qID); err != nil { + log.Warnf("index answer %s failed: %v", a.ID, err) + continue + } + totalIndexed++ + } + } + } + + if int64((page)*50) >= searchResp.Total { + break + } + page++ + } + + log.Infof("Embedding indexer completed: %d items indexed", totalIndexed) + return nil +} + +// StartScheduler starts a cron job to periodically run IndexAll. +func (s *EmbeddingService) StartScheduler(spec string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Stop existing cron if running + if s.cronJob != nil { + s.cronJob.Stop() + s.cronJob = nil + s.cronSpec = "" + } + + if spec == "" { + return nil + } + + c := cron.New() + _, err := c.AddFunc(spec, func() { + ctx := context.Background() + log.Infof("embedding cron triggered (spec=%s)", spec) + if err := s.IndexAll(ctx); err != nil { + log.Errorf("embedding cron IndexAll failed: %v", err) + } + }) + if err != nil { + return fmt.Errorf("invalid cron expression %q: %w", spec, err) + } + + c.Start() + s.cronJob = c + s.cronSpec = spec + log.Infof("embedding scheduler started with cron: %s", spec) + return nil +} + +// StopScheduler stops the embedding cron scheduler. +func (s *EmbeddingService) StopScheduler() { + s.mu.Lock() + defer s.mu.Unlock() + if s.cronJob != nil { + s.cronJob.Stop() + s.cronJob = nil + s.cronSpec = "" + log.Infof("embedding scheduler stopped") + } +} + +// ApplyConfig reads the current AI config and starts or stops the scheduler accordingly. +func (s *EmbeddingService) ApplyConfig(ctx context.Context) { + aiConfig, provider, err := s.getAIConfig(ctx) + if err != nil || aiConfig == nil || provider == nil { + s.StopScheduler() + return + } + + if provider.EmbeddingModel == "" || provider.EmbeddingCrontab == "" { + s.StopScheduler() + return + } + + // Only restart if the cron spec changed + s.mu.Lock() + currentSpec := s.cronSpec + s.mu.Unlock() + + if currentSpec == provider.EmbeddingCrontab { + return + } + + if err := s.StartScheduler(provider.EmbeddingCrontab); err != nil { + log.Errorf("failed to start embedding scheduler: %v", err) + } +} diff --git a/internal/service/provider.go b/internal/service/provider.go index 3e43b0ae0..26f1c4309 100644 --- a/internal/service/provider.go +++ b/internal/service/provider.go @@ -36,6 +36,7 @@ import ( "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -134,4 +135,5 @@ var ProviderSetService = wire.NewSet( apikey.NewAPIKeyService, ai_conversation.NewAIConversationService, feature_toggle.NewFeatureToggleService, + embedding.NewEmbeddingService, ) diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 1e25cbaa4..70003984c 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -33,6 +33,7 @@ import ( "github.com/apache/answer/internal/entity" "github.com/apache/answer/internal/schema" "github.com/apache/answer/internal/service/config" + embeddingService "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/file_record" questioncommon "github.com/apache/answer/internal/service/question_common" @@ -53,6 +54,7 @@ type SiteInfoService struct { configService *config.ConfigService questioncommon *questioncommon.QuestionCommon fileRecordService *file_record.FileRecordService + embeddingService *embeddingService.EmbeddingService } func NewSiteInfoService( @@ -63,7 +65,7 @@ func NewSiteInfoService( configService *config.ConfigService, questioncommon *questioncommon.QuestionCommon, fileRecordService *file_record.FileRecordService, - + embeddingSvc *embeddingService.EmbeddingService, ) *SiteInfoService { plugin.RegisterGetSiteURLFunc(func() string { generalSiteInfo, err := siteInfoCommonService.GetSiteGeneral(context.Background()) @@ -82,6 +84,7 @@ func NewSiteInfoService( configService: configService, questioncommon: questioncommon, fileRecordService: fileRecordService, + embeddingService: embeddingSvc, } } @@ -409,7 +412,13 @@ func (s *SiteInfoService) SaveSiteAI(ctx context.Context, req *schema.SiteAIReq) Content: string(content), Status: 1, } - return s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo) + if err := s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo); err != nil { + return err + } + + // Apply embedding scheduler config (start/stop cron based on settings) + go s.embeddingService.ApplyConfig(ctx) + return nil } func (s *SiteInfoService) maskAIKeys(resp *schema.SiteAIResp) { diff --git a/ui/src/common/interface.ts b/ui/src/common/interface.ts index 308726e80..b01b621f7 100644 --- a/ui/src/common/interface.ts +++ b/ui/src/common/interface.ts @@ -833,6 +833,11 @@ export interface AiConfig { api_host: string; api_key: string; model: string; + embedding_model: string; + embedding_dimensions: number; + embedding_level: string; + embedding_crontab: string; + similarity_threshold: number; }>; } diff --git a/ui/src/pages/Admin/AiSettings/index.tsx b/ui/src/pages/Admin/AiSettings/index.tsx index 2270aa5c5..de284ff00 100644 --- a/ui/src/pages/Admin/AiSettings/index.tsx +++ b/ui/src/pages/Admin/AiSettings/index.tsx @@ -68,6 +68,31 @@ const Index = () => { isInvalid: false, errorMsg: '', }, + embedding_model: { + value: '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: '', + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: '0', + isInvalid: false, + errorMsg: '', + }, }); const [apiHostPlaceholder, setApiHostPlaceholder] = useState(''); const [modelsData, setModels] = useState<{ id: string }[]>([]); @@ -146,6 +171,31 @@ const Index = () => { isInvalid: false, errorMsg: '', }, + embedding_model: { + value: findHistoryProvider?.embedding_model || '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: String(findHistoryProvider?.embedding_dimensions || ''), + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: findHistoryProvider?.embedding_level || 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: findHistoryProvider?.embedding_crontab || '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: String(findHistoryProvider?.similarity_threshold || '0'), + isInvalid: false, + errorMsg: '', + }, }); const provider = aiProviders?.find((item) => item.name === value); const host = findHistoryProvider?.api_host || provider?.default_api_host; @@ -218,6 +268,13 @@ const Index = () => { api_host: formData.api_host.value, api_key: formData.api_key.value, model: formData.model.value, + embedding_model: formData.embedding_model.value, + embedding_dimensions: + Number(formData.embedding_dimensions.value) || 0, + embedding_level: formData.embedding_level.value, + embedding_crontab: formData.embedding_crontab.value, + similarity_threshold: + Number(formData.similarity_threshold.value) || 0, }; } return v; @@ -295,6 +352,31 @@ const Index = () => { isInvalid: false, errorMsg: '', }, + embedding_model: { + value: currentAiConfig?.embedding_model || '', + isInvalid: false, + errorMsg: '', + }, + embedding_dimensions: { + value: String(currentAiConfig?.embedding_dimensions || ''), + isInvalid: false, + errorMsg: '', + }, + embedding_level: { + value: currentAiConfig?.embedding_level || 'question', + isInvalid: false, + errorMsg: '', + }, + embedding_crontab: { + value: currentAiConfig?.embedding_crontab || '', + isInvalid: false, + errorMsg: '', + }, + similarity_threshold: { + value: String(currentAiConfig?.similarity_threshold || '0'), + isInvalid: false, + errorMsg: '', + }, }); }; @@ -477,6 +559,99 @@ const Index = () => {
{formData.model.errorMsg}
+
+
{t('embedding_settings')}
+ + + {t('embedding_model.label')} + + handleValueChange({ + embedding_model: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('embedding_model.text')} + + + + + {t('embedding_level.label')} + + handleValueChange({ + embedding_level: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + }> + + + + + {t('embedding_level.text')} + + + + + {t('embedding_crontab.label')} + + handleValueChange({ + embedding_crontab: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('embedding_crontab.text')} + + + + + {t('similarity_threshold.label')} + + handleValueChange({ + similarity_threshold: { + value: e.target.value, + errorMsg: '', + isInvalid: false, + }, + }) + } + /> + + {t('similarity_threshold.text')} + + +