diff --git a/config.json b/config.json index 5a6fdef..8f6d29a 100644 --- a/config.json +++ b/config.json @@ -69,6 +69,16 @@ }, "server_url":"https://api.minimax.chat/v1/text/chatcompletion_pro" } + ], + "jiutian": [ + { + "models": ["jiutian-qianwen"], + "enabled": true, + "credentials": { + "api_key": "xxx" + }, + "server_url":"https://jiutian.10086.cn/largemodel/api/v1" + } ] } } diff --git a/go.mod b/go.mod index 069b0d9..10e66b8 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.7 // indirect diff --git a/go.sum b/go.sum index 52452fd..5420cea 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= diff --git a/pkg/adapter/jiutian_openai.go b/pkg/adapter/jiutian_openai.go new file mode 100644 index 0000000..e17505b --- /dev/null +++ b/pkg/adapter/jiutian_openai.go @@ -0,0 +1,169 @@ +package adapter + +import ( + "github.com/sashabaranov/go-openai" + "go.uber.org/zap" + "simple-one-api/pkg/llm/jiutian" + "simple-one-api/pkg/mylog" + "time" +) + +// convertFinishReason 将字符串转换为OpenAI的FinishReason类型 +func convertFinishReason(reason string) openai.FinishReason { + switch reason { + case "stop": + return openai.FinishReasonStop + case "length": + return openai.FinishReasonLength + case "content_filter": + return openai.FinishReasonContentFilter + case "function_call": + return openai.FinishReasonFunctionCall + default: + return openai.FinishReasonNull + } +} + +// OpenAIRequestToJiuTianRequest 将OpenAI请求转换为九天模型请求 +func OpenAIRequestToJiuTianRequest(oaiReq *openai.ChatCompletionRequest) *jiutian.ChatCompletionRequest { + mylog.Logger.Info("Converting OpenAI request to JiuTian request", + zap.String("model", oaiReq.Model), + zap.Int("message_count", len(oaiReq.Messages)), + zap.Float32("temperature", oaiReq.Temperature), + zap.Float32("top_p", oaiReq.TopP)) + + // 获取最后一条消息作为prompt + lastMessage := oaiReq.Messages[len(oaiReq.Messages)-1].Content + + // 构建历史消息 + var history [][]string + if len(oaiReq.Messages) > 1 { + for i := 0; i < len(oaiReq.Messages)-1; i += 2 { + if i+1 < len(oaiReq.Messages) { + history = append(history, []string{ + oaiReq.Messages[i].Content, + oaiReq.Messages[i+1].Content, + }) + } + } + } + + mylog.Logger.Debug("Request conversion details", + zap.String("prompt", lastMessage), + zap.Int("history_length", len(history))) + + // 创建请求 + req := jiutian.NewChatCompletionRequest(). + WithModelID(oaiReq.Model). // 使用传入的模型ID + WithPrompt(lastMessage). + WithHistory(history). + WithStream(oaiReq.Stream) + + // 设置温度参数(如果有) + if oaiReq.Temperature > 0 { + req.WithTemperature(oaiReq.Temperature) + } + + // 设置top_p参数(如果有) + if oaiReq.TopP > 0 { + req.WithTopP(oaiReq.TopP) + } + + mylog.Logger.Debug("Created JiuTian request", + zap.String("model_id", req.ModelID), + zap.Bool("stream", req.Stream), + zap.Float32("temperature", req.Params.Temperature), + zap.Float32("top_p", req.Params.TopP)) + + return req +} + +// JiuTianResponseToOpenAIResponse 将九天模型响应转换为OpenAI响应 +func JiuTianResponseToOpenAIResponse(jiutianResp *jiutian.ChatCompletionResponse) *openai.ChatCompletionResponse { + mylog.Logger.Info("Converting JiuTian response to OpenAI response", + zap.Any("jiutian_response", map[string]interface{}{ + "usage": jiutianResp.Usage, + "response": jiutianResp.Response, + "delta": jiutianResp.Delta, + "finished": jiutianResp.Finished, + "history": jiutianResp.History, + })) + + // 创建选项 + choice := openai.ChatCompletionChoice{ + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: "assistant", + Content: jiutianResp.Response, + }, + } + + // 设置结束原因 + if jiutianResp.Finished != "" { + choice.FinishReason = convertFinishReason(jiutianResp.Finished) + } + + resp := &openai.ChatCompletionResponse{ + ID: "jiutian-" + time.Now().Format("20060102150405"), + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "Llama3.1-70B", // 使用实际的模型ID + Choices: []openai.ChatCompletionChoice{choice}, + Usage: openai.Usage{ + PromptTokens: jiutianResp.Usage.PromptTokens, + CompletionTokens: jiutianResp.Usage.CompletionTokens, + TotalTokens: jiutianResp.Usage.TotalTokens, + }, + } + + mylog.Logger.Info("Converted to OpenAI response", + zap.Any("openai_response", map[string]interface{}{ + "id": resp.ID, + "model": resp.Model, + "choices": resp.Choices, + "usage": resp.Usage, + })) + + return resp +} + +// JiuTianStreamResponseToOpenAIStreamResponse 将九天模型的流式响应转换为OpenAI流式响应 +func JiuTianStreamResponseToOpenAIStreamResponse(jiutianResp *jiutian.ChatCompletionStreamResponse) *openai.ChatCompletionStreamResponse { + mylog.Logger.Info("Converting JiuTian stream response to OpenAI stream response", + zap.Any("jiutian_stream_response", map[string]interface{}{ + "response": jiutianResp.Response, + "delta": jiutianResp.Delta, + "finished": jiutianResp.Finished, + "history": jiutianResp.History, + })) + + choice := openai.ChatCompletionStreamChoice{ + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + Content: jiutianResp.Response, + }, + } + + // 只在收到结束标记时设置结束原因 + if jiutianResp.Delta == "[EOS]" { + choice.FinishReason = convertFinishReason(jiutianResp.Finished) + } + + resp := &openai.ChatCompletionStreamResponse{ + ID: "jiutian-stream-" + time.Now().Format("20060102150405"), + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: "Llama3.1-70B", // 使用实际的模型ID + Choices: []openai.ChatCompletionStreamChoice{choice}, + } + + mylog.Logger.Info("Converted to OpenAI stream response", + zap.Any("openai_stream_response", map[string]interface{}{ + "id": resp.ID, + "model": resp.Model, + "choices": resp.Choices, + })) + + return resp +} \ No newline at end of file diff --git a/pkg/handler/openai_handler.go b/pkg/handler/openai_handler.go index b62f5a8..8fb63e2 100644 --- a/pkg/handler/openai_handler.go +++ b/pkg/handler/openai_handler.go @@ -4,9 +4,6 @@ import ( "bytes" "context" "errors" - "github.com/gin-gonic/gin" - "github.com/sashabaranov/go-openai" - "go.uber.org/zap" "io" "net/http" "simple-one-api/pkg/adapter" @@ -17,6 +14,10 @@ import ( "simple-one-api/pkg/utils" "strings" "time" + + "github.com/gin-gonic/gin" + "github.com/sashabaranov/go-openai" + "go.uber.org/zap" ) var defaultReqTimeout = 10 @@ -52,6 +53,7 @@ var serviceHandlerMap = map[string]func(*gin.Context, *OAIRequestParam) error{ "claude": OpenAI2ClaudeHandler, "agentbuilder": OpenAI2AgentBuilderHandler, "dify": OpenAI2DifyHandler, + "jiutian": OpenAI2JiuTianHandler, } func LogRequestDetails(c *gin.Context) { diff --git a/pkg/handler/openai_jiutian_handler.go b/pkg/handler/openai_jiutian_handler.go new file mode 100644 index 0000000..86bbbb3 --- /dev/null +++ b/pkg/handler/openai_jiutian_handler.go @@ -0,0 +1,186 @@ +package handler + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "simple-one-api/pkg/adapter" + "simple-one-api/pkg/config" + "simple-one-api/pkg/llm/jiutian" + "simple-one-api/pkg/mylog" + "simple-one-api/pkg/utils" + "go.uber.org/zap" +) + +// OpenAI2JiuTianHandler 处理OpenAI到九天模型的请求转换 +func OpenAI2JiuTianHandler(c *gin.Context, oaiReqParam *OAIRequestParam) error { + mylog.Logger.Info("Starting JiuTian request handling") + + oaiReq := oaiReqParam.chatCompletionReq + s := oaiReqParam.modelDetails + + // 获取API Key + apiKey, _ := utils.GetStringFromMap(oaiReqParam.creds, config.KEYNAME_API_KEY) + if apiKey == "" { + return fmt.Errorf("API key not found") + } + + // 转换请求 + jiutianReq := adapter.OpenAIRequestToJiuTianRequest(oaiReq) + + // 分别设置各个参数 + jiutianReq.WithAPIKey(apiKey) + jiutianReq.WithBaseURL(s.ServerURL) + + // 确保transport被正确设置 + if oaiReqParam.httpTransport != nil { + mylog.Logger.Debug("Setting custom transport for JiuTian request") + jiutianReq.WithTransport(oaiReqParam.httpTransport) + } else { + mylog.Logger.Debug("Using default transport for JiuTian request") + jiutianReq.WithTransport(http.DefaultTransport) + } + + // 处理流式请求 + if oaiReq.Stream { + return handleJiuTianStreamRequest(c, jiutianReq, oaiReqParam.ClientModel) + } + + // 处理非流式请求 + return handleJiuTianNonStreamRequest(c, jiutianReq, oaiReqParam.ClientModel) +} + +// handleJiuTianStreamRequest 处理流式请求 +func handleJiuTianStreamRequest(c *gin.Context, jiutianReq *jiutian.ChatCompletionRequest, clientModel string) error { + mylog.Logger.Info("Handling JiuTian stream request") + + // 记录请求头信息 + mylog.Logger.Info("Original request headers", + zap.Any("headers", c.Request.Header)) + + // 记录请求内容 + reqData, _ := json.Marshal(jiutianReq) + mylog.Logger.Info("Request content", + zap.String("request_body", string(reqData))) + + // 发送流式请求 + resp, err := jiutianReq.CreateCompletionStream() + if err != nil { + return err + } + defer resp.Body.Close() + + // 设置响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + // 记录设置的响应头 + mylog.Logger.Info("Set response headers", + zap.String("content_type", c.Writer.Header().Get("Content-Type")), + zap.String("cache_control", c.Writer.Header().Get("Cache-Control")), + zap.String("connection", c.Writer.Header().Get("Connection"))) + + // 读取并转发响应 + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + return err + } + + // 处理SSE数据 + if bytes.HasPrefix(line, []byte("data: ")) { + data := bytes.TrimPrefix(line, []byte("data: ")) + // 去掉可能存在的换行符 + data = bytes.TrimSpace(data) + + mylog.Logger.Info("Received stream data", + zap.String("raw_data", string(data)), + zap.Int("data_length", len(data))) + + var jiutianResp jiutian.ChatCompletionStreamResponse + if err := json.Unmarshal(data, &jiutianResp); err != nil { + mylog.Logger.Error("Failed to parse stream response", + zap.Error(err), + zap.String("data", string(data))) + continue + } + + // 记录解析后的九天响应 + mylog.Logger.Info("Parsed JiuTian response", + zap.Any("jiutian_response", map[string]interface{}{ + "response": jiutianResp.Response, + "delta": jiutianResp.Delta, + "finished": jiutianResp.Finished, + "history": jiutianResp.History, + })) + + // 转换为OpenAI流式响应 + streamResp := adapter.JiuTianStreamResponseToOpenAIStreamResponse(&jiutianResp) + streamResp.Model = clientModel + + // 发送响应 + responseData, _ := json.Marshal(streamResp) + mylog.Logger.Info("Sending stream response", + zap.String("response_data", string(responseData))) + + c.Writer.Write([]byte("data: ")) + c.Writer.Write(responseData) + c.Writer.Write([]byte("\n\n")) + c.Writer.Flush() + } + } + + return nil +} + +// handleJiuTianNonStreamRequest 处理非流式请求 +func handleJiuTianNonStreamRequest(c *gin.Context, jiutianReq *jiutian.ChatCompletionRequest, clientModel string) error { + mylog.Logger.Info("Handling JiuTian non-stream request") + + // 记录请求头信息 + mylog.Logger.Info("Original request headers", + zap.Any("headers", c.Request.Header)) + + // 记录请求内容 + reqData, _ := json.Marshal(jiutianReq) + mylog.Logger.Info("Request content", + zap.String("request_body", string(reqData))) + + // 发送请求 + jiutianResp, err := jiutianReq.CreateCompletion() + if err != nil { + return err + } + + // 记录九天响应 + mylog.Logger.Info("Received JiuTian response", + zap.Any("jiutian_response", map[string]interface{}{ + "usage": jiutianResp.Usage, + "response": jiutianResp.Response, + "delta": jiutianResp.Delta, + "finished": jiutianResp.Finished, + "history": jiutianResp.History, + })) + + // 转换为OpenAI响应 + chatResp := adapter.JiuTianResponseToOpenAIResponse(jiutianResp) + chatResp.Model = clientModel + + // 记录响应信息 + responseData, _ := json.Marshal(chatResp) + mylog.Logger.Info("Sending final response", + zap.String("response_data", string(responseData))) + + // 发送响应 + c.JSON(http.StatusOK, chatResp) + return nil +} \ No newline at end of file diff --git a/pkg/llm/jiutian/jiutian_request.go b/pkg/llm/jiutian/jiutian_request.go new file mode 100644 index 0000000..df238cf --- /dev/null +++ b/pkg/llm/jiutian/jiutian_request.go @@ -0,0 +1,396 @@ +package jiutian + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/golang-jwt/jwt/v5" + "go.uber.org/zap" + "io" + "net/http" + "simple-one-api/pkg/mylog" + "strings" + "time" +) + +const ( + // DefaultBaseURL 默认的九天模型API基础URL + DefaultBaseURL = "https://jiutian.10086.cn/largemodel/api/v1" + // DefaultTimeout 默认超时时间 + DefaultTimeout = 3 * time.Minute +) + +// Message 九天模型的消息结构 +type Message struct { + Role string `json:"role"` // 角色:system/user/assistant + Content string `json:"content"` // 消息内容 +} + +// ChatCompletionRequest 九天模型的对话请求结构 +type ChatCompletionRequest struct { + ModelID string `json:"modelId"` // 模型ID + Prompt string `json:"prompt"` // 当前问题 + Params *Params `json:"params"` // 参数配置 + History [][]string `json:"history"` // 历史对话记录 + Stream bool `json:"stream"` // 是否使用流式响应 + apiKey string // API密钥 + baseURL string // API基础URL + transport http.RoundTripper // HTTP传输层 +} + +// Params 模型参数配置 +type Params struct { + Temperature float32 `json:"temperature"` // 温度参数 + TopP float32 `json:"top_p"` // 核采样参数 +} + +// NewChatCompletionRequest 创建新的对话请求 +func NewChatCompletionRequest() *ChatCompletionRequest { + return &ChatCompletionRequest{ + ModelID: "Llama3.1-70B", // 默认模型 + Params: &Params{ + Temperature: 0.7, // 默认温度 + TopP: 0.95, // 默认top_p + }, + History: make([][]string, 0), + Stream: false, + baseURL: DefaultBaseURL, + } +} + +// WithModelID 设置模型ID +func (r *ChatCompletionRequest) WithModelID(modelID string) *ChatCompletionRequest { + r.ModelID = modelID + return r +} + +// WithPrompt 设置当前问题 +func (r *ChatCompletionRequest) WithPrompt(prompt string) *ChatCompletionRequest { + r.Prompt = prompt + return r +} + +// WithHistory 设置历史对话记录 +func (r *ChatCompletionRequest) WithHistory(history [][]string) *ChatCompletionRequest { + r.History = history + return r +} + +// WithStream 设置是否使用流式响应 +func (r *ChatCompletionRequest) WithStream(stream bool) *ChatCompletionRequest { + r.Stream = stream + return r +} + +// WithTemperature 设置温度参数 +func (r *ChatCompletionRequest) WithTemperature(temperature float32) *ChatCompletionRequest { + r.Params.Temperature = temperature + return r +} + +// WithTopP 设置top_p参数 +func (r *ChatCompletionRequest) WithTopP(topP float32) *ChatCompletionRequest { + r.Params.TopP = topP + return r +} + +// WithAPIKey 设置API密钥 +func (r *ChatCompletionRequest) WithAPIKey(apiKey string) *ChatCompletionRequest { + r.apiKey = apiKey + return r +} + +// WithBaseURL 设置API基础URL +func (r *ChatCompletionRequest) WithBaseURL(baseURL string) *ChatCompletionRequest { + if baseURL != "" { + r.baseURL = baseURL + } + return r +} + +// WithTransport 设置HTTP传输层 +func (r *ChatCompletionRequest) WithTransport(transport http.RoundTripper) *ChatCompletionRequest { + r.transport = transport + return r +} + +// validate 验证请求参数 +func (r *ChatCompletionRequest) validate() error { + if r.ModelID == "" { + return errors.New("modelId is required") + } + if r.Prompt == "" { + return errors.New("prompt is required") + } + if r.Params == nil { + return errors.New("params is required") + } + if r.Params.Temperature < 0 || r.Params.Temperature > 2 { + return errors.New("temperature must be between 0 and 2") + } + if r.Params.TopP < 0 || r.Params.TopP > 1 { + return errors.New("top_p must be between 0 and 1") + } + if r.apiKey == "" { + return errors.New("API key is required") + } + return nil +} + +// generateToken 生成JWT token +func (r *ChatCompletionRequest) generateToken() (string, error) { + mylog.Logger.Info("Generating JiuTian JWT token") + + // 分割API Key + parts := strings.Split(r.apiKey, ".") + if len(parts) != 2 { + mylog.Logger.Error("Invalid API key format", zap.String("api_key", r.apiKey)) + return "", errors.New("invalid API key format") + } + id, secret := parts[0], parts[1] + mylog.Logger.Debug("API key parsed", zap.String("id", id)) + + // 创建token + now := time.Now().Unix() + claims := jwt.MapClaims{ + "api_key": id, + "exp": now + 3600, // 1小时有效期 + "timestamp": now, + } + + mylog.Logger.Debug("Creating JWT claims", + zap.String("api_key", id), + zap.Int64("exp", now+3600), + zap.Int64("timestamp", now)) + + // 设置header + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["alg"] = "HS256" + token.Header["typ"] = "JWT" + token.Header["sign_type"] = "SIGN" + + mylog.Logger.Debug("JWT header set", + zap.String("alg", "HS256"), + zap.String("typ", "JWT"), + zap.String("sign_type", "SIGN")) + + // 签名 + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + mylog.Logger.Error("Failed to sign token", + zap.Error(err), + zap.String("id", id)) + return "", fmt.Errorf("failed to sign token: %v", err) + } + + mylog.Logger.Info("JWT token generated successfully", + zap.String("token", tokenString), + zap.Int64("expires_at", now+3600)) + + return tokenString, nil +} + +// CreateCompletion 发送非流式请求 +func (r *ChatCompletionRequest) CreateCompletion() (*ChatCompletionResponse, error) { + mylog.Logger.Info("Creating JiuTian chat completion") + + // 验证请求参数 + if err := r.validate(); err != nil { + return nil, err + } + + // 生成token + token, err := r.generateToken() + if err != nil { + mylog.Logger.Error("Failed to generate token", zap.Error(err)) + return nil, err + } + + // 准备请求URL + url := fmt.Sprintf("%s/completions", r.baseURL) + mylog.Logger.Info("Accessing JiuTian API", zap.String("full_url", url)) + + // 准备请求数据 + jsonData, err := json.Marshal(r) + if err != nil { + return nil, err + } + + // 创建HTTP客户端 + client := &http.Client{ + Timeout: DefaultTimeout, + } + + // 设置Transport + if r.transport != nil { + client.Transport = r.transport + } else { + client.Transport = http.DefaultTransport + } + + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // 记录完整的请求信息 + mylog.Logger.Info("Request details", + zap.String("url", url), + zap.String("method", httpReq.Method), + zap.String("content_type", httpReq.Header.Get("Content-Type")), + zap.String("authorization", "Bearer "+token[:10]+"..."), // 只显示token的前10个字符 + zap.String("request_body", string(jsonData))) + + // 发送请求 + resp, err := client.Do(httpReq) + if err != nil { + mylog.Logger.Error("Failed to send request", + zap.Error(err), + zap.String("url", url)) + return nil, err + } + defer resp.Body.Close() + + // 记录响应头信息 + mylog.Logger.Info("Response headers", + zap.Int("status_code", resp.StatusCode), + zap.String("content_type", resp.Header.Get("Content-Type")), + zap.String("content_length", resp.Header.Get("Content-Length")), + zap.Any("all_headers", resp.Header)) + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + mylog.Logger.Error("API request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("body", string(body))) + return nil, fmt.Errorf("API request failed with status code: %d, body: %s", resp.StatusCode, string(body)) + } + + // 处理响应数据 + responseStr := string(body) + mylog.Logger.Debug("Raw response", zap.String("body", responseStr)) + + // 如果响应以 "data:" 开头,需要去掉这个前缀 + if strings.HasPrefix(responseStr, "data:") { + responseStr = strings.TrimPrefix(responseStr, "data:") + // 去掉可能存在的换行符 + responseStr = strings.TrimSpace(responseStr) + } + + // 解析响应 + var response ChatCompletionResponse + if err := json.Unmarshal([]byte(responseStr), &response); err != nil { + mylog.Logger.Error("Failed to parse response", + zap.Error(err), + zap.String("body", responseStr)) + return nil, err + } + + mylog.Logger.Debug("Received response from JiuTian API", + zap.Any("usage", response.Usage), + zap.String("response", response.Response), + zap.String("finished", response.Finished)) + + return &response, nil +} + +// CreateCompletionStream 发送流式请求 +func (r *ChatCompletionRequest) CreateCompletionStream() (*http.Response, error) { + mylog.Logger.Info("Creating JiuTian stream chat completion") + + // 验证请求参数 + if err := r.validate(); err != nil { + return nil, err + } + + // 生成token + token, err := r.generateToken() + if err != nil { + mylog.Logger.Error("Failed to generate token", zap.Error(err)) + return nil, err + } + + // 准备请求URL + url := fmt.Sprintf("%s/completions", r.baseURL) + mylog.Logger.Info("Accessing JiuTian API Stream", zap.String("full_url", url)) + + // 准备请求数据 + jsonData, err := json.Marshal(r) + if err != nil { + return nil, err + } + + // 创建HTTP客户端 + client := &http.Client{ + Timeout: DefaultTimeout, + } + + // 设置Transport + if r.transport != nil { + client.Transport = r.transport + } else { + client.Transport = http.DefaultTransport + } + + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + httpReq.Header.Set("Connection", "keep-alive") + + // 记录完整的请求信息 + mylog.Logger.Info("Stream request details", + zap.String("url", url), + zap.String("method", httpReq.Method), + zap.String("content_type", httpReq.Header.Get("Content-Type")), + zap.String("accept", httpReq.Header.Get("Accept")), + zap.String("cache_control", httpReq.Header.Get("Cache-Control")), + zap.String("connection", httpReq.Header.Get("Connection")), + zap.String("authorization", "Bearer "+token[:10]+"..."), // 只显示token的前10个字符 + zap.String("request_body", string(jsonData))) + + // 发送请求 + resp, err := client.Do(httpReq) + if err != nil { + mylog.Logger.Error("Failed to send stream request", + zap.Error(err), + zap.String("url", url)) + return nil, err + } + + // 记录响应头信息 + mylog.Logger.Info("Stream response headers", + zap.Int("status_code", resp.StatusCode), + zap.String("content_type", resp.Header.Get("Content-Type")), + zap.String("transfer_encoding", resp.Header.Get("Transfer-Encoding")), + zap.Any("all_headers", resp.Header)) + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + mylog.Logger.Error("Stream API request failed", + zap.Int("status_code", resp.StatusCode)) + return nil, fmt.Errorf("API request failed with status code: %d", resp.StatusCode) + } + + return resp, nil +} \ No newline at end of file diff --git a/pkg/llm/jiutian/jiutian_response.go b/pkg/llm/jiutian/jiutian_response.go new file mode 100644 index 0000000..cb2f4db --- /dev/null +++ b/pkg/llm/jiutian/jiutian_response.go @@ -0,0 +1,25 @@ +package jiutian + +// Usage 使用统计 +type Usage struct { + PromptTokens int `json:"prompt_tokens"` // 输入的token数量 + CompletionTokens int `json:"completion_tokens"` // 生成的token数量 + TotalTokens int `json:"total_tokens"` // 总token数量 +} + +// ChatCompletionResponse 九天模型的对话响应结构 +type ChatCompletionResponse struct { + Usage Usage `json:"Usage"` // 使用统计 + Response string `json:"response"` // 模型回答 + Delta string `json:"delta"` // 结束标记 + Finished string `json:"finished"` // 结束原因 + History [][]string `json:"history"` // 历史对话记录 +} + +// ChatCompletionStreamResponse 九天模型的流式响应结构 +type ChatCompletionStreamResponse struct { + Response string `json:"response"` // 当前生成的内容 + Delta string `json:"delta"` // 结束标记 + Finished string `json:"finished"` // 结束原因 + History [][]string `json:"history"` // 历史对话记录 +} \ No newline at end of file