Skip to content

Commit 916ab57

Browse files
committed
refactor(daemon): 更新日志文件权限和改进用户目录获取逻辑
- 修改日志文件的权限设置,从0644更改为0600,以提高安全性 - 在获取用户主目录时添加错误处理,确保在无法获取时使用当前目录 - 更新停止标记文件的权限设置,确保安全性 - 优化证书部署和更新逻辑,添加域名不能为空的检查
1 parent c222a0a commit 916ab57

File tree

8 files changed

+123
-45
lines changed

8 files changed

+123
-45
lines changed

cmd/daemon.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func runSupervisor() {
9191
return
9292
}
9393

94-
supervisorLogFile, err := os.OpenFile(GetLogFile(), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
94+
supervisorLogFile, err := os.OpenFile(GetLogFile(), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
9595
if err != nil {
9696
fmt.Printf("打开日志文件失败: %v\n", err)
9797
return
@@ -120,7 +120,7 @@ func runSupervisor() {
120120

121121
cmd := exec.Command(execPath, "start", "-c", ConfigFile)
122122

123-
logFile, err := os.OpenFile(GetLogFile(), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
123+
logFile, err := os.OpenFile(GetLogFile(), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
124124
if err != nil {
125125
time.Sleep(restartDelay)
126126
continue
@@ -138,7 +138,7 @@ func runSupervisor() {
138138
}
139139

140140
pidFile := GetPIDFile()
141-
if err := os.WriteFile(pidFile, []byte(strconv.Itoa(cmd.Process.Pid)), 0644); err != nil {
141+
if err := os.WriteFile(pidFile, []byte(strconv.Itoa(cmd.Process.Pid)), 0600); err != nil {
142142
cmd.Process.Kill()
143143
logFile.Close()
144144
time.Sleep(restartDelay)
@@ -183,7 +183,11 @@ func runSupervisor() {
183183

184184
// shouldStopSupervisor 检查是否应该停止监控器
185185
func shouldStopSupervisor() bool {
186-
homeDir, _ := os.UserHomeDir()
186+
homeDir, err := os.UserHomeDir()
187+
if err != nil {
188+
// 如果无法获取用户主目录,使用当前目录
189+
homeDir = "."
190+
}
187191
stopMarker := filepath.Join(homeDir, ".cert-deploy-stop")
188192
if _, err := os.Stat(stopMarker); err == nil {
189193
os.Remove(stopMarker)
@@ -194,9 +198,15 @@ func shouldStopSupervisor() bool {
194198

195199
// StopDaemon 停止守护进程
196200
func StopDaemon() error {
197-
homeDir, _ := os.UserHomeDir()
201+
homeDir, err := os.UserHomeDir()
202+
if err != nil {
203+
// 如果无法获取用户主目录,使用当前目录
204+
homeDir = "."
205+
}
198206
stopMarker := filepath.Join(homeDir, ".cert-deploy-stop")
199-
os.WriteFile(stopMarker, []byte("stop"), 0644)
207+
if err := os.WriteFile(stopMarker, []byte("stop"), 0600); err != nil {
208+
return fmt.Errorf("创建停止标记失败: %w", err)
209+
}
200210

201211
pidFile := GetPIDFile()
202212
data, err := os.ReadFile(pidFile)

cmd/root.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ func CreateRootCmd() *cobra.Command {
4343

4444
// GetPIDFile 获取PID文件路径
4545
func GetPIDFile() string {
46-
homeDir, _ := os.UserHomeDir()
46+
homeDir, err := os.UserHomeDir()
47+
if err != nil {
48+
// 如果无法获取用户主目录,使用当前目录
49+
homeDir = "."
50+
}
4751
return filepath.Join(homeDir, ".cert-deploy.pid")
4852
}
4953

internal/client/cert_deploy.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,17 @@ func (cd *CertDeployer) extractZip(zipFile, extractDir string) error {
128128
func (cd *CertDeployer) extractZipFile(file *zip.File, extractDir string) error {
129129
// 使用 filepath.Rel 安全地检查路径
130130
targetPath := filepath.Join(extractDir, file.Name)
131-
rel, err := filepath.Rel(extractDir, targetPath)
132-
if err != nil || strings.HasPrefix(rel, "..") {
131+
132+
// 清理路径并检查符号链接
133+
cleanTarget := filepath.Clean(targetPath)
134+
rel, err := filepath.Rel(extractDir, cleanTarget)
135+
if err != nil || strings.HasPrefix(rel, "..") || strings.Contains(rel, ".."+string(filepath.Separator)) {
133136
return fmt.Errorf("不安全的文件路径: %s", file.Name)
134137
}
135138

139+
// 使用清理后的路径
140+
targetPath = cleanTarget
141+
136142
// 创建目录
137143
if file.FileInfo().IsDir() {
138144
return os.MkdirAll(targetPath, file.FileInfo().Mode())

internal/client/client.go

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/url"
1010
"os"
11+
"sync"
1112
"sync/atomic"
1213
"time"
1314

@@ -35,7 +36,9 @@ type Client struct {
3536
connectClient deployPBconnect.DeployServiceClient
3637
ctx context.Context
3738
accessKey string
38-
lastDisconnectLogged atomic.Bool // 记录是否已打印断开连接日志
39+
lastDisconnectLogged atomic.Bool // 记录是否已打印断开连接日志
40+
systemInfo *system.SystemInfo // 缓存的系统信息
41+
systemInfoOnce sync.Once // 确保系统信息只获取一次
3942
}
4043

4144
func NewClient(ctx context.Context) (*Client, error) {
@@ -48,15 +51,27 @@ func NewClient(ctx context.Context) (*Client, error) {
4851
}
4952

5053
// 配置 HTTP 客户端
51-
httpClient := &http.Client{}
54+
httpClient := &http.Client{
55+
Timeout: 30 * time.Second,
56+
}
5257
if cfg.Server.Env == "local" {
5358
p := new(http.Protocols)
5459
p.SetUnencryptedHTTP2(true)
5560
httpClient = &http.Client{
61+
Timeout: 30 * time.Second,
5662
Transport: &http.Transport{
57-
Protocols: p,
63+
Protocols: p,
64+
MaxIdleConns: 100,
65+
MaxIdleConnsPerHost: 10,
66+
IdleConnTimeout: 90 * time.Second,
5867
},
5968
}
69+
} else {
70+
httpClient.Transport = &http.Transport{
71+
MaxIdleConns: 100,
72+
MaxIdleConnsPerHost: 10,
73+
IdleConnTimeout: 90 * time.Second,
74+
}
6075
}
6176

6277
client := &Client{
@@ -75,6 +90,15 @@ func NewClient(ctx context.Context) (*Client, error) {
7590
return client, nil
7691
}
7792

93+
// getSystemInfo 获取系统信息(带缓存)
94+
func (c *Client) getSystemInfo() (*system.SystemInfo, error) {
95+
var err error
96+
c.systemInfoOnce.Do(func() {
97+
c.systemInfo, err = system.GetSystemInfo()
98+
})
99+
return c.systemInfo, err
100+
}
101+
78102
// StartConnectNotify 启动连接通知
79103
func (c *Client) StartConnectNotify() {
80104
reconnectDelay := time.Second
@@ -110,10 +134,10 @@ func (c *Client) StartConnectNotify() {
110134
consecutiveFailures = 0
111135
reconnectDelay = time.Second
112136

113-
// 获取系统信息
114-
systemInfo, err := system.GetSystemInfo()
137+
// 获取系统信息(使用缓存)
138+
systemInfo, err := c.getSystemInfo()
115139
if err != nil {
116-
logger.Error("获取系统信息失败: %v", err)
140+
logger.Error("获取系统信息失败", "error", err)
117141
stream.CloseRequest()
118142
time.Sleep(reconnectDelay)
119143
continue
@@ -138,31 +162,32 @@ func (c *Client) StartConnectNotify() {
138162

139163
// 注册客户端
140164
if err := stream.Send(registerReq); err != nil {
141-
// logger.Error("注册失败", "error", err)
142165
stream.CloseRequest()
143166
time.Sleep(reconnectDelay)
144167
continue
145168
}
146169

147-
// 处理消息流
148-
err = c.handleNotifyStream(stream)
149-
150-
// 流断开
151-
if err != nil {
152-
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
153-
return
154-
}
155-
156-
isConnected.Store(false)
157-
c.lastDisconnectLogged.Store(true)
170+
// 流断开,先检查主 context 是否被取消(而不是检查错误类型)
171+
// 因为错误链中可能包含 context.Canceled,但实际是连接断开导致的
172+
select {
173+
case <-c.ctx.Done():
174+
logger.Info("主 context 已取消,退出连接循环")
175+
return
176+
default:
177+
}
158178

159-
// 等待后重连
160-
time.Sleep(reconnectDelay)
161-
reconnectDelay = min(reconnectDelay*2, maxReconnectDelay)
162-
continue
179+
// 处理消息流
180+
if err := c.handleNotifyStream(stream); err != nil {
181+
// logger.Error("连接断开", "error", err)
163182
}
164183

165-
return
184+
// 标记断开连接
185+
isConnected.Store(false)
186+
c.lastDisconnectLogged.Store(true)
187+
188+
// 等待后重连
189+
time.Sleep(reconnectDelay)
190+
reconnectDelay = min(reconnectDelay*2, maxReconnectDelay)
166191
}
167192
}
168193

@@ -197,7 +222,9 @@ func (c *Client) handleNotifyStream(stream *connect.BidiStreamForClientSimple[de
197222
if !isConnected.Load() {
198223
isConnected.Store(true)
199224

225+
// 如果之前断开过连接,打印重连成功日志
200226
if c.lastDisconnectLogged.Load() {
227+
// logger.Info("重新连接成功")
201228
c.lastDisconnectLogged.Store(false)
202229
}
203230
}
@@ -249,7 +276,7 @@ func (c *Client) sendHeartbeat(ctx context.Context, stream *connect.BidiStreamFo
249276
Version: config.Version,
250277
})
251278
if err != nil {
252-
logger.Error("发送心跳失败: %v", err)
279+
// logger.Error("发送心跳失败", "error", err)
253280
return
254281
}
255282
}

internal/client/execute_busines.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ func (c *Client) executeBusines(stream *connect.BidiStreamForClientSimple[deploy
2222
cert := resp.Cert
2323
key := resp.Key
2424

25+
if domain == "" {
26+
logger.Error("域名不能为空")
27+
c.sendExecuteBusinesResponse(stream, requestId, deployPB.ExecuteBusinesRequest_REQUEST_RESULT_FAILED)
28+
return
29+
}
30+
2531
// 上传证书备注
2632
remark := domain + "_" + time.Now().Format(time.DateTime)
2733

@@ -117,6 +123,11 @@ func (c *Client) sendExecuteBusinesResponse(stream *connect.BidiStreamForClientS
117123

118124
// deployCertificate 部署证书
119125
func (c *Client) deployCertificate(domain, downloadURL string) {
126+
if domain == "" {
127+
logger.Error("域名不能为空")
128+
return
129+
}
130+
120131
deployer := NewCertDeployer(c)
121132
if err := deployer.DeployCertificate(domain, downloadURL); err != nil {
122133
logger.Error("证书部署失败", "error", err, "domain", domain)

internal/client/update_version.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func (c *Client) handleUpdate() {
1616

1717
updateInfo, err := updater.CheckUpdate(c.ctx)
1818
if err != nil {
19-
logger.Error("检查更新失败", err)
19+
logger.Error("检查更新失败", "error", err)
2020
return
2121
}
2222

@@ -27,18 +27,25 @@ func (c *Client) handleUpdate() {
2727
logger.Info("发现新版本", "current", updateInfo.CurrentVersion, "latest", updateInfo.LatestVersion)
2828

2929
if err := updater.PerformUpdate(c.ctx, updateInfo); err != nil {
30-
logger.Error("更新失败", err)
30+
logger.Error("更新失败", "error", err)
3131
return
3232
}
3333

3434
logger.Info("更新完成,重启中...")
3535

3636
// 创建更新标记文件
37-
execPath, _ := os.Executable()
37+
execPath, err := os.Executable()
38+
if err != nil {
39+
logger.Error("获取可执行文件路径失败", "error", err)
40+
return
41+
}
3842
execDir := filepath.Dir(execPath)
3943
markerFile := filepath.Join(execDir, ".cert-deploy-updated")
4044
content := fmt.Sprintf("%s\n%s\n", updateInfo.LatestVersion, time.Now().Format(time.RFC3339))
41-
os.WriteFile(markerFile, []byte(content), 0644)
45+
if err := os.WriteFile(markerFile, []byte(content), 0600); err != nil {
46+
logger.Error("创建更新标记文件失败", "error", err)
47+
return
48+
}
4249

4350
time.Sleep(1 * time.Second)
4451
os.Exit(0)

internal/system/info.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func GetUniqueClientID(ctx context.Context) (string, error) {
8686
id := hex.EncodeToString(sum[:])
8787

8888
// 缓存结果,确保下次启动时使用相同的ID
89-
_ = writeCachedID(id)
89+
if err := writeCachedID(id); err != nil {
90+
// 缓存失败不影响主流程,仅记录错误
91+
fmt.Fprintf(os.Stderr, "警告: 缓存客户端ID失败: %v\n", err)
92+
}
9093

9194
return id, nil
9295
}

internal/system/network.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@ import (
1212
"github.com/orange-juzipi/cert-deploy/internal/config"
1313
)
1414

15+
// 共享的 HTTP Client,避免重复创建
16+
var publicIPClient = &http.Client{
17+
Timeout: 5 * time.Second,
18+
Transport: &http.Transport{
19+
MaxIdleConns: 10,
20+
IdleConnTimeout: 30 * time.Second,
21+
DisableKeepAlives: false,
22+
MaxIdleConnsPerHost: 2,
23+
},
24+
}
25+
1526
// getPublicIP 获取公网IP地址(使用并发请求优化)
1627
func getPublicIP() string {
1728
// 使用多个服务提供商,并发请求,提高成功率和速度
@@ -30,14 +41,17 @@ func getPublicIP() string {
3041

3142
var wg sync.WaitGroup
3243
for _, serviceURL := range services {
33-
wg.Go(func() {
44+
serviceURL := serviceURL // 捕获循环变量
45+
wg.Add(1)
46+
go func() {
47+
defer wg.Done()
3448
if ip := getIPFromService(ctx, serviceURL); ip != "" {
3549
select {
3650
case resultChan <- ip:
3751
case <-ctx.Done():
3852
}
3953
}
40-
})
54+
}()
4155
}
4256

4357
// 启动一个 goroutine 来关闭通道
@@ -69,10 +83,6 @@ func getPublicIP() string {
6983

7084
// getIPFromService 从指定服务获取IP地址
7185
func getIPFromService(ctx context.Context, serviceURL string) string {
72-
client := &http.Client{
73-
Timeout: 5 * time.Second,
74-
}
75-
7686
req, err := http.NewRequestWithContext(ctx, "GET", serviceURL, nil)
7787
if err != nil {
7888
return ""
@@ -81,7 +91,7 @@ func getIPFromService(ctx context.Context, serviceURL string) string {
8191
// 使用配置的版本号
8292
req.Header.Set("User-Agent", "cert-deploy-client/"+config.Version)
8393

84-
resp, err := client.Do(req)
94+
resp, err := publicIPClient.Do(req)
8595
if err != nil {
8696
return ""
8797
}

0 commit comments

Comments
 (0)