diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml
index 3d977913..9cc5cd86 100644
--- a/.github/workflows/build-release.yml
+++ b/.github/workflows/build-release.yml
@@ -38,6 +38,9 @@ jobs:
matrix:
goos: [linux, darwin, windows]
goarch: [amd64, arm64]
+ include:
+ - goos: android
+ goarch: arm64
exclude:
- goos: windows
goarch: arm64
diff --git a/api/auth.go b/api/auth.go
index 236bec3c..2a15d5a8 100644
--- a/api/auth.go
+++ b/api/auth.go
@@ -14,35 +14,63 @@ type tokenContextKey struct{}
// AuthMiddleware 返回认证中间件
func AuthMiddleware() func(http.Handler) http.Handler {
+ return TokenAuthMiddleware(func() string {
+ return config.C().API.Token
+ })
+}
+
+func TokenAuthMiddleware(tokenProvider func() string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- cfg := config.C().API
-
- // 从请求头获取 token
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header")
+ if isPublicConfigWebPath(r) {
+ next.ServeHTTP(w, r)
return
}
- // 提取 Bearer token
- parts := strings.SplitN(authHeader, " ", 2)
- if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
- WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid authorization header format")
+ token := tokenProvider()
+ if token == "" {
+ next.ServeHTTP(w, r)
return
}
- token := parts[1]
+ requestToken := getRequestToken(r)
+ if requestToken == "" {
+ WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header")
+ return
+ }
- // 验证 token
- if subtle.ConstantTimeCompare([]byte(token), []byte(cfg.Token)) != 1 {
+ if subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) != 1 {
WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid token")
return
}
- // 将 token 添加到 context
- ctx := context.WithValue(r.Context(), tokenContextKey{}, token)
+ ctx := context.WithValue(r.Context(), tokenContextKey{}, requestToken)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
+
+func getRequestToken(r *http.Request) string {
+ authHeader := r.Header.Get("Authorization")
+ if authHeader != "" {
+ parts := strings.SplitN(authHeader, " ", 2)
+ if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
+ return ""
+ }
+ return parts[1]
+ }
+ if token := r.Header.Get("X-API-Token"); token != "" {
+ return token
+ }
+ if cookie, err := r.Cookie("saveany_api_token"); err == nil {
+ return cookie.Value
+ }
+ return r.URL.Query().Get("token")
+}
+
+func isPublicConfigWebPath(r *http.Request) bool {
+ if r.Method != http.MethodGet && r.Method != http.MethodHead {
+ return false
+ }
+ return r.URL.Path == "/config" || strings.HasPrefix(r.URL.Path, "/config/")
+}
diff --git a/api/config_editor.go b/api/config_editor.go
new file mode 100644
index 00000000..19832089
--- /dev/null
+++ b/api/config_editor.go
@@ -0,0 +1,653 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/krau/SaveAny-Bot/common/logbuffer"
+ "github.com/krau/SaveAny-Bot/common/utils/netutil"
+ "github.com/krau/SaveAny-Bot/config"
+ "github.com/krau/SaveAny-Bot/database"
+ "github.com/krau/SaveAny-Bot/pkg/rule"
+)
+
+type ConfigEditor struct {
+ ctx context.Context
+ configPath string
+ autoOpenDatabase bool
+ mu sync.Mutex
+}
+
+type ConfigEditorOption func(*ConfigEditor)
+
+func WithConfigEditorAutoOpenDatabase(autoOpen bool) ConfigEditorOption {
+ return func(editor *ConfigEditor) {
+ editor.autoOpenDatabase = autoOpen
+ }
+}
+
+func RegisterConfigEditorRoutes(ctx context.Context, mux *http.ServeMux, configPath string, opts ...ConfigEditorOption) *ConfigEditor {
+ editor := &ConfigEditor{
+ ctx: ctx,
+ configPath: config.ResolveConfigFilePath(configPath),
+ }
+ for _, opt := range opts {
+ opt(editor)
+ }
+ mux.HandleFunc("/config", editor.WebHandler)
+ mux.HandleFunc("/config/", editor.WebHandler)
+ mux.HandleFunc("/api/v1/config", editor.ConfigHandler)
+ mux.HandleFunc("/api/v1/config/apply", editor.ApplyConfigHandler)
+ mux.HandleFunc("/api/v1/config/schema", editor.SchemaHandler)
+ mux.HandleFunc("/api/v1/config/logs", editor.LogsHandler)
+ mux.HandleFunc("/api/v1/config/proxy-test", editor.ProxyTestHandler)
+ mux.HandleFunc("/api/v1/config/update-check", editor.UpdateCheckHandler)
+ mux.HandleFunc("/api/v1/config/rules/apply", editor.ApplyRuleHandler)
+ mux.HandleFunc("/api/v1/config/rules", editor.RulesHandler)
+ mux.HandleFunc("/api/v1/config/rules/", editor.RuleByIDHandler)
+ return editor
+}
+
+func (e *ConfigEditor) WebHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet && r.Method != http.MethodHead {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(configWebHTML)
+}
+
+func (e *ConfigEditor) ConfigHandler(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ file, err := config.LoadEditableConfig(e.configPath)
+ if err != nil {
+ WriteError(w, http.StatusInternalServerError, "config_load_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "path": file.Path,
+ "exists": file.Exists,
+ "config": file.Config,
+ "message": "保存后建议重启 bot 以完整应用存储、代理和高级配置。",
+ })
+ case http.MethodPut:
+ cfg, err := decodeEditableConfigRequest(r)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", err.Error())
+ return
+ }
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ storageChanges := e.storageChanges(cfg)
+ if err := config.SaveEditableConfig(e.configPath, cfg); err != nil {
+ WriteError(w, http.StatusBadRequest, "config_save_failed", err.Error())
+ return
+ }
+ warnings := e.applyStorageReferenceChanges(r.Context(), storageChanges)
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "message": "config saved",
+ "path": e.configPath,
+ "warnings": warnings,
+ })
+ default:
+ MethodNotAllowedHandler(w, r)
+ }
+}
+
+func (e *ConfigEditor) ApplyConfigHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ before := config.C()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ applyResult, err := e.reloadRuntimeConfig(r.Context(), before)
+ if err != nil {
+ WriteJSON(w, http.StatusAccepted, map[string]any{
+ "message": "runtime reload failed; restart the bot after fixing the config",
+ "path": e.configPath,
+ "reload_error": err.Error(),
+ "runtime": applyResult,
+ })
+ return
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "message": "config applied",
+ "path": e.configPath,
+ "runtime": applyResult,
+ })
+}
+
+func (e *ConfigEditor) LogsHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ limit := 200
+ if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" {
+ parsed, err := strconv.Atoi(raw)
+ if err != nil || parsed < 1 {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer")
+ return
+ }
+ limit = parsed
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "lines": logbuffer.Default().Lines(limit),
+ })
+}
+
+func (e *ConfigEditor) UpdateCheckHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ resp, err := CheckUpdate()
+ if err != nil {
+ WriteError(w, http.StatusBadGateway, "update_check_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusOK, resp)
+}
+
+func (e *ConfigEditor) ProxyTestHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ var req ProxyTestRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
+ return
+ }
+ req.URL = strings.TrimSpace(req.URL)
+ if req.URL == "" {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "proxy url is required")
+ return
+ }
+ target := strings.TrimSpace(req.Target)
+ if target == "" {
+ target = "https://api.telegram.org"
+ }
+ client, err := netutil.NewProxyHTTPClient(req.URL)
+ if err != nil {
+ WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: false, Message: err.Error(), Target: target})
+ return
+ }
+ client.Timeout = 10 * time.Second
+ ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
+ defer cancel()
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", err.Error())
+ return
+ }
+ start := time.Now()
+ resp, err := client.Do(httpReq)
+ elapsed := time.Since(start).Milliseconds()
+ if err != nil {
+ WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: false, MS: elapsed, Message: err.Error(), Target: target})
+ return
+ }
+ defer resp.Body.Close()
+ ok := resp.StatusCode >= 200 && resp.StatusCode < 500
+ message := resp.Status
+ if ok {
+ message = "proxy reachable"
+ }
+ WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: ok, MS: elapsed, Message: message, Target: target})
+}
+
+func (e *ConfigEditor) SchemaHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ ruleTypes := make([]string, 0, len(rule.Values()))
+ for _, ruleType := range rule.Values() {
+ ruleTypes = append(ruleTypes, ruleType.String())
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "storage_types": config.StorageTypeNames(),
+ "storage_schemas": config.StorageSchemas(),
+ "rule_types": ruleTypes,
+ "rule_storage_chosen": rule.RuleStorNameChosen,
+ "rule_dir_new_album": rule.RuleDirPathNewForAlbum,
+ "database_ready": database.Ready(),
+ "config_path": e.configPath,
+ "config_reload_notice": "保存配置后建议重启 bot。",
+ })
+}
+
+func (e *ConfigEditor) RulesHandler(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ if err := e.ensureDatabase(r.Context()); err != nil {
+ WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error())
+ return
+ }
+ users, err := database.GetAllUsers(r.Context())
+ if err != nil {
+ WriteError(w, http.StatusInternalServerError, "users_load_failed", err.Error())
+ return
+ }
+ chatIDFilter, hasFilter, err := optionalChatID(r)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", err.Error())
+ return
+ }
+ respUsers := make([]configUserResponse, 0, len(users))
+ for _, user := range users {
+ if hasFilter && user.ChatID != chatIDFilter {
+ continue
+ }
+ respUsers = append(respUsers, convertConfigUser(user))
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "users": respUsers,
+ "storages": configuredStorageNames(),
+ })
+ case http.MethodPost:
+ if err := e.ensureDatabase(r.Context()); err != nil {
+ WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error())
+ return
+ }
+ var req createRuleRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
+ return
+ }
+ normalized, err := normalizeCreateRuleRequest(req)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_rule", err.Error())
+ return
+ }
+ user, err := database.GetUserByChatID(r.Context(), normalized.ChatID)
+ if err != nil {
+ WriteError(w, http.StatusNotFound, "user_not_found", fmt.Sprintf("user %d is not in database; save config and reload first", normalized.ChatID))
+ return
+ }
+ newRule := &database.Rule{
+ UserID: user.ID,
+ Type: normalized.Type,
+ Data: normalized.Data,
+ StorageName: normalized.StorageName,
+ DirPath: normalized.DirPath,
+ }
+ if err := database.CreateRule(r.Context(), newRule); err != nil {
+ WriteError(w, http.StatusInternalServerError, "rule_create_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusCreated, convertConfigRule(*newRule))
+ default:
+ MethodNotAllowedHandler(w, r)
+ }
+}
+
+func (e *ConfigEditor) RuleByIDHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodDelete {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ if err := e.ensureDatabase(r.Context()); err != nil {
+ WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error())
+ return
+ }
+ id, err := ruleIDFromPath(r.URL.Path)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", err.Error())
+ return
+ }
+ if err := database.DeleteRule(r.Context(), id); err != nil {
+ WriteError(w, http.StatusInternalServerError, "rule_delete_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{"message": "rule deleted"})
+}
+
+func (e *ConfigEditor) ApplyRuleHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPut && r.Method != http.MethodPatch {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ if err := e.ensureDatabase(r.Context()); err != nil {
+ WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error())
+ return
+ }
+ var req updateRuleModeRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
+ return
+ }
+ if req.ChatID == 0 {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "chat_id is required")
+ return
+ }
+ if err := database.UpdateUserApplyRule(r.Context(), req.ChatID, req.ApplyRule); err != nil {
+ WriteError(w, http.StatusInternalServerError, "rule_mode_update_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "chat_id": req.ChatID,
+ "apply_rule": req.ApplyRule,
+ })
+}
+
+func (e *ConfigEditor) reloadRuntimeConfig(ctx context.Context, before config.Config) (RuntimeApplyResult, error) {
+ if err := config.Init(ctx, e.configPath); err != nil {
+ return RuntimeApplyResult{}, fmt.Errorf("saved config but failed to reload runtime config: %w", err)
+ }
+ result := RuntimeApplyResult{}
+ if e.autoOpenDatabase {
+ if !database.Ready() {
+ if err := database.Open(ctx); err != nil {
+ result.Warnings = append(result.Warnings, "数据库打开失败: "+err.Error())
+ }
+ }
+ }
+ result.Warnings = append(result.Warnings, e.cleanupRemovedStorageReferences(ctx, removedRuntimeStorageNames(before, config.C()))...)
+ applied := ApplyRuntimeConfig(ctx, e.configPath, before)
+ applied.Warnings = append(result.Warnings, applied.Warnings...)
+ return applied, nil
+}
+
+type storageReferenceChanges struct {
+ Renamed map[string]string
+ Removed []string
+}
+
+func (e *ConfigEditor) storageChanges(next *config.EditableConfig) storageReferenceChanges {
+ current, err := config.LoadEditableConfig(e.configPath)
+ if err != nil {
+ return storageReferenceChanges{}
+ }
+ return diffEditableStorages(current.Config.Storages, next.Storages)
+}
+
+func (e *ConfigEditor) applyStorageReferenceChanges(ctx context.Context, changes storageReferenceChanges) []string {
+ warnings := make([]string, 0)
+ if database.Ready() {
+ for oldName, newName := range changes.Renamed {
+ if err := database.RenameStorageReferences(ctx, oldName, newName); err != nil {
+ warnings = append(warnings, fmt.Sprintf("迁移存储 %s 到 %s 的用户关联失败: %v", oldName, newName, err))
+ }
+ }
+ }
+ warnings = append(warnings, e.cleanupRemovedStorageReferences(ctx, changes.Removed)...)
+ return warnings
+}
+
+func (e *ConfigEditor) cleanupRemovedStorageReferences(ctx context.Context, names []string) []string {
+ names = compactStrings(names)
+ if len(names) == 0 || !database.Ready() {
+ return nil
+ }
+ warnings := make([]string, 0)
+ for _, name := range names {
+ if err := database.ClearStorageReferences(ctx, name); err != nil {
+ warnings = append(warnings, fmt.Sprintf("清理存储 %s 的用户关联失败: %v", name, err))
+ }
+ }
+ return warnings
+}
+
+func diffEditableStorages(before, after []map[string]any) storageReferenceChanges {
+ beforeNames := editableStorageNames(before)
+ afterNames := editableStorageNames(after)
+ changes := storageReferenceChanges{
+ Renamed: make(map[string]string),
+ Removed: make([]string, 0),
+ }
+ for i := range before {
+ if i >= len(after) {
+ continue
+ }
+ oldName := strings.TrimSpace(fmt.Sprint(before[i]["name"]))
+ newName := strings.TrimSpace(fmt.Sprint(after[i]["name"]))
+ if oldName == "" || newName == "" || oldName == newName {
+ continue
+ }
+ if _, oldStillExists := afterNames[oldName]; oldStillExists {
+ continue
+ }
+ if _, newExistedBefore := beforeNames[newName]; newExistedBefore {
+ continue
+ }
+ changes.Renamed[oldName] = newName
+ }
+ for oldName := range beforeNames {
+ if _, ok := afterNames[oldName]; ok {
+ continue
+ }
+ if _, renamed := changes.Renamed[oldName]; renamed {
+ continue
+ }
+ changes.Removed = append(changes.Removed, oldName)
+ }
+ return changes
+}
+
+func removedRuntimeStorageNames(before, after config.Config) []string {
+ beforeNames := make(map[string]struct{}, len(before.Storages))
+ afterNames := make(map[string]struct{}, len(after.Storages))
+ for _, storage := range before.Storages {
+ beforeNames[storage.GetName()] = struct{}{}
+ }
+ for _, storage := range after.Storages {
+ afterNames[storage.GetName()] = struct{}{}
+ }
+ return missingStorageNames(beforeNames, afterNames)
+}
+
+func editableStorageNames(storages []map[string]any) map[string]struct{} {
+ names := make(map[string]struct{}, len(storages))
+ for _, storage := range storages {
+ name := strings.TrimSpace(fmt.Sprint(storage["name"]))
+ if name != "" {
+ names[name] = struct{}{}
+ }
+ }
+ return names
+}
+
+func missingStorageNames(before, after map[string]struct{}) []string {
+ missing := make([]string, 0)
+ for name := range before {
+ if _, ok := after[name]; !ok {
+ missing = append(missing, name)
+ }
+ }
+ return missing
+}
+
+func (e *ConfigEditor) ensureDatabase(ctx context.Context) error {
+ if database.Ready() {
+ return nil
+ }
+ if !e.autoOpenDatabase {
+ return fmt.Errorf("database is not initialized")
+ }
+ if err := config.Init(ctx, e.configPath); err != nil {
+ return fmt.Errorf("load config before opening database: %w", err)
+ }
+ return database.Open(ctx)
+}
+
+func decodeEditableConfigRequest(r *http.Request) (*config.EditableConfig, error) {
+ var raw json.RawMessage
+ if err := json.NewDecoder(r.Body).Decode(&raw); err != nil {
+ return nil, fmt.Errorf("failed to decode request body: %w", err)
+ }
+ var wrapped struct {
+ Config *config.EditableConfig `json:"config"`
+ }
+ if err := json.Unmarshal(raw, &wrapped); err == nil && wrapped.Config != nil {
+ return wrapped.Config, nil
+ }
+ var cfg config.EditableConfig
+ if err := json.Unmarshal(raw, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to decode config: %w", err)
+ }
+ return &cfg, nil
+}
+
+type configUserResponse struct {
+ ID uint `json:"id"`
+ ChatID int64 `json:"chat_id"`
+ ApplyRule bool `json:"apply_rule"`
+ DefaultStorage string `json:"default_storage"`
+ DefaultDir uint `json:"default_dir"`
+ Rules []configRuleResponse `json:"rules"`
+}
+
+type configRuleResponse struct {
+ ID uint `json:"id"`
+ Type string `json:"type"`
+ Data string `json:"data"`
+ StorageName string `json:"storage_name"`
+ DirPath string `json:"dir_path"`
+}
+
+type createRuleRequest struct {
+ ChatID int64 `json:"chat_id"`
+ Type string `json:"type"`
+ Data string `json:"data"`
+ StorageName string `json:"storage_name"`
+ DirPath string `json:"dir_path"`
+}
+
+type updateRuleModeRequest struct {
+ ChatID int64 `json:"chat_id"`
+ ApplyRule bool `json:"apply_rule"`
+}
+
+func convertConfigUser(user database.User) configUserResponse {
+ rules := make([]configRuleResponse, 0, len(user.Rules))
+ for _, userRule := range user.Rules {
+ rules = append(rules, convertConfigRule(userRule))
+ }
+ return configUserResponse{
+ ID: user.ID,
+ ChatID: user.ChatID,
+ ApplyRule: user.ApplyRule,
+ DefaultStorage: user.DefaultStorage,
+ DefaultDir: user.DefaultDir,
+ Rules: rules,
+ }
+}
+
+func convertConfigRule(userRule database.Rule) configRuleResponse {
+ return configRuleResponse{
+ ID: userRule.ID,
+ Type: userRule.Type,
+ Data: userRule.Data,
+ StorageName: userRule.StorageName,
+ DirPath: userRule.DirPath,
+ }
+}
+
+func optionalChatID(r *http.Request) (int64, bool, error) {
+ raw := strings.TrimSpace(r.URL.Query().Get("chat_id"))
+ if raw == "" {
+ return 0, false, nil
+ }
+ chatID, err := strconv.ParseInt(raw, 10, 64)
+ if err != nil {
+ return 0, false, fmt.Errorf("invalid chat_id: %w", err)
+ }
+ return chatID, true, nil
+}
+
+func ruleIDFromPath(path string) (uint, error) {
+ raw := strings.TrimPrefix(path, "/api/v1/config/rules/")
+ if raw == "" || strings.Contains(raw, "/") {
+ return 0, fmt.Errorf("rule id is required")
+ }
+ id, err := strconv.ParseUint(raw, 10, 64)
+ if err != nil || id == 0 {
+ return 0, fmt.Errorf("invalid rule id")
+ }
+ return uint(id), nil
+}
+
+func normalizeCreateRuleRequest(req createRuleRequest) (createRuleRequest, error) {
+ req.Type = strings.ToUpper(strings.TrimSpace(req.Type))
+ req.Data = strings.TrimSpace(req.Data)
+ req.StorageName = strings.TrimSpace(req.StorageName)
+ req.DirPath = strings.TrimSpace(req.DirPath)
+ if req.ChatID == 0 {
+ return req, fmt.Errorf("chat_id is required")
+ }
+ if req.Type == "" {
+ return req, fmt.Errorf("type is required")
+ }
+ if req.Data == "" {
+ return req, fmt.Errorf("data is required")
+ }
+ if req.StorageName == "" {
+ return req, fmt.Errorf("storage_name is required")
+ }
+ if req.DirPath == "" {
+ return req, fmt.Errorf("dir_path is required")
+ }
+ validType := false
+ for _, value := range rule.Values() {
+ if req.Type == value.String() {
+ validType = true
+ break
+ }
+ }
+ if !validType {
+ return req, fmt.Errorf("invalid rule type: %s", req.Type)
+ }
+ switch req.Type {
+ case rule.FileNameRegex.String(), rule.MessageRegex.String():
+ if _, err := regexp.Compile(req.Data); err != nil {
+ return req, fmt.Errorf("invalid regex: %w", err)
+ }
+ case rule.IsAlbum.String():
+ if _, err := strconv.ParseBool(req.Data); err != nil {
+ return req, fmt.Errorf("IS-ALBUM data must be true or false")
+ }
+ }
+ if req.StorageName != rule.RuleStorNameChosen && !configuredStorageNameExists(req.StorageName) {
+ return req, fmt.Errorf("unknown storage_name: %s", req.StorageName)
+ }
+ return req, nil
+}
+
+func configuredStorageNames() []string {
+ cfg, err := config.LoadEditableConfig(config.ConfigFileUsed())
+ if err != nil {
+ log.Warnf("failed to load configured storage names: %v", err)
+ return nil
+ }
+ names := make([]string, 0, len(cfg.Config.Storages))
+ for _, storage := range cfg.Config.Storages {
+ name := strings.TrimSpace(fmt.Sprint(storage["name"]))
+ if name != "" {
+ names = append(names, name)
+ }
+ }
+ return names
+}
+
+func configuredStorageNameExists(name string) bool {
+ for _, configuredName := range configuredStorageNames() {
+ if configuredName == name {
+ return true
+ }
+ }
+ return false
+}
diff --git a/api/config_server.go b/api/config_server.go
new file mode 100644
index 00000000..ea25d20b
--- /dev/null
+++ b/api/config_server.go
@@ -0,0 +1,71 @@
+package api
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/krau/SaveAny-Bot/config"
+)
+
+type ConfigWebServerOptions struct {
+ ConfigPath string
+ Host string
+ Port int
+ Token string
+}
+
+func NewConfigWebServer(ctx context.Context, opts ConfigWebServerOptions) *http.Server {
+ mux := http.NewServeMux()
+ RegisterConfigEditorRoutes(ctx, mux, opts.ConfigPath, WithConfigEditorAutoOpenDatabase(true))
+ mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+ WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
+ })
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/" {
+ http.Redirect(w, r, "/config", http.StatusFound)
+ return
+ }
+ NotFoundHandler(w, r)
+ })
+
+ var handler http.Handler = mux
+ handler = TokenAuthMiddleware(func() string {
+ if opts.Token != "" {
+ return opts.Token
+ }
+ return config.C().API.Token
+ })(handler)
+ handler = loggingMiddleware(handler)
+ handler = recoveryMiddleware(handler)
+
+ return &http.Server{
+ Addr: fmt.Sprintf("%s:%d", opts.Host, opts.Port),
+ Handler: handler,
+ ReadTimeout: 30 * time.Second,
+ WriteTimeout: 30 * time.Second,
+ IdleTimeout: 120 * time.Second,
+ }
+}
+
+func StartConfigWebServer(ctx context.Context, opts ConfigWebServerOptions) (*http.Server, error) {
+ server := NewConfigWebServer(ctx, opts)
+ logger := log.FromContext(ctx).With("module", "config-web")
+ logger.Infof("Starting config web server on %s", server.Addr)
+ go func() {
+ if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ logger.Errorf("Config web server error: %v", err)
+ }
+ }()
+ go func() {
+ <-ctx.Done()
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := server.Shutdown(shutdownCtx); err != nil {
+ logger.Errorf("Config web server shutdown error: %v", err)
+ }
+ }()
+ return server, nil
+}
diff --git a/api/config_web.go b/api/config_web.go
new file mode 100644
index 00000000..e296bb7e
--- /dev/null
+++ b/api/config_web.go
@@ -0,0 +1,6 @@
+package api
+
+import _ "embed"
+
+//go:embed config_web.html
+var configWebHTML []byte
diff --git a/api/config_web.html b/api/config_web.html
new file mode 100644
index 00000000..47658163
--- /dev/null
+++ b/api/config_web.html
@@ -0,0 +1,1186 @@
+
+
+
+
+
+ SaveAny-Bot 配置
+
+
+
+
+
+
+
+
+
+
+
+ 验证
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/api/factory.go b/api/factory.go
index 4c00a6bd..51b9c84e 100644
--- a/api/factory.go
+++ b/api/factory.go
@@ -38,8 +38,8 @@ func NewTaskFactory(ctx context.Context) *TaskFactory {
// CreateTask 创建任务
func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) {
// 验证存储
- stor, ok := storage.Storages[req.Storage]
- if !ok {
+ stor, err := storage.GetStorageByName(f.ctx, req.Storage)
+ if err != nil {
return nil, fmt.Errorf("storage not found: %s", req.Storage)
}
@@ -66,19 +66,6 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
}
}
-func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
- taskID := task.TaskID()
- RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
-
- err := core.AddTask(f.ctx, NewExecutableWrapper(task))
- if err != nil {
- DeleteTask(taskID)
- return fmt.Errorf("failed to add task: %w", err)
- }
-
- return nil
-}
-
// createDirectLinksTask 创建直链下载任务
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params DirectLinksParams
@@ -90,11 +77,12 @@ func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time,
return nil, fmt.Errorf("no URLs provided")
}
- task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
+ task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, newDirectLinksAPIProgress(taskID))
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
- err := f.registerAndEnqueueTask(task, tasktype.TaskTypeDirectlinks, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
@@ -116,11 +104,12 @@ func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *C
return nil, fmt.Errorf("no URLs provided")
}
- task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
+ task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, newYTDLPAPIProgress(taskID))
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
- err := f.registerAndEnqueueTask(task, tasktype.TaskTypeYtdlp, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
@@ -159,11 +148,12 @@ func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *C
return nil, fmt.Errorf("failed to add aria2 task: %w", err)
}
- task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
+ task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, newAria2APIProgress(taskID))
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
- err = f.registerAndEnqueueTask(task, tasktype.TaskTypeAria2, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
@@ -204,11 +194,12 @@ func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *
return nil, fmt.Errorf("failed to parse URL: %w", err)
}
- task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
+ task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, newParsedAPIProgress(taskID))
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
- err = f.registerAndEnqueueTask(task, tasktype.TaskTypeParseditem, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
@@ -240,15 +231,17 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
return nil, fmt.Errorf("no files found in provided links")
}
- var task core.Executable
-
if len(files) == 1 {
// 单个文件任务
- tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
+ tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, newTFileAPIProgress(taskID))
if err != nil {
return nil, fmt.Errorf("failed to create tfile task: %w", err)
}
- task = tfileTask
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, tfileTask.Title(), req.Webhook, req)
+ if err := core.AddTask(f.ctx, tfileTask); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
+ }
} else {
// 批量文件任务
elems := make([]batchtfile.TaskElement, 0, len(files))
@@ -260,12 +253,12 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
elems = append(elems, *elem)
}
- task = batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
- }
-
- err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTgfiles, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, newBatchTFileAPIProgress(taskID), true)
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
+ }
}
return &CreateTaskResponse{
@@ -298,11 +291,12 @@ func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req
}
client := telegraph.NewClient()
- task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
+ task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, newTelegraphAPIProgress(taskID))
+ RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req)
- err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTphpics, req.Storage, req.Path, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
@@ -321,13 +315,13 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req
}
// 验证源存储和目标存储
- sourceStor, ok := storage.Storages[params.SourceStorage]
- if !ok {
+ sourceStor, err := storage.GetStorageByName(f.ctx, params.SourceStorage)
+ if err != nil {
return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage)
}
- targetStor, ok := storage.Storages[params.TargetStorage]
- if !ok {
+ targetStor, err := storage.GetStorageByName(f.ctx, params.TargetStorage)
+ if err != nil {
return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage)
}
@@ -360,11 +354,14 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req
elems = append(elems, *elem)
}
- task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
+ task := transfer.NewTransferTask(taskID, f.ctx, elems, newTransferAPIProgress(taskID), true)
+ info := RegisterTask(taskID, req.Type.String(), params.TargetStorage, params.TargetPath, task.Title(), req.Webhook, req)
+ info.SetTransferMeta(params.SourceStorage, params.SourcePath, params.TargetStorage, params.TargetPath)
+ RegisterTaskControl(taskID, task)
- err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTransfer, params.TargetStorage, params.TargetPath, req.Webhook)
- if err != nil {
- return nil, err
+ if err := core.AddTask(f.ctx, task); err != nil {
+ DeleteTask(taskID)
+ return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
diff --git a/api/handlers.go b/api/handlers.go
index 81946d02..db321d7a 100644
--- a/api/handlers.go
+++ b/api/handlers.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"strings"
+ "sync/atomic"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
@@ -117,6 +118,15 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}
+ if r.URL.Query().Get("record") == "1" || r.URL.Query().Get("record") == "true" {
+ if task.Status == TaskStatusQueued || task.Status == TaskStatusRunning {
+ _ = core.CancelTask(r.Context(), taskID)
+ }
+ DeleteTask(taskID)
+ WriteJSON(w, http.StatusOK, map[string]string{"message": "task deleted successfully"})
+ return
+ }
+
// 取消任务
if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
@@ -127,6 +137,106 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"})
}
+func (h *Handlers) PauseTaskHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ taskID, _ := extractTaskIDAndAction(r.URL.Path)
+ if taskID == "" {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
+ return
+ }
+ task, ok := GetTask(taskID)
+ if !ok {
+ WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
+ return
+ }
+ if task.Status == TaskStatusQueued || task.Status == TaskStatusRunning {
+ if err := core.CancelTask(r.Context(), taskID); err != nil {
+ WriteError(w, http.StatusInternalServerError, "pause_failed", "failed to pause task: "+err.Error())
+ return
+ }
+ }
+ task.UpdateStatus(TaskStatusPaused)
+ task.UpdatePhase("paused")
+ WriteJSON(w, http.StatusOK, map[string]string{"message": "task paused successfully"})
+}
+
+func (h *Handlers) RetryTaskHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ taskID, _ := extractTaskIDAndAction(r.URL.Path)
+ if taskID == "" {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
+ return
+ }
+ task, ok := GetTask(taskID)
+ if !ok {
+ WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
+ return
+ }
+ if task.Request == nil {
+ WriteError(w, http.StatusBadRequest, "retry_unavailable", "task request data is not available")
+ return
+ }
+ if task.Status != TaskStatusFailed && task.Status != TaskStatusCancelled && task.Status != TaskStatusPaused {
+ WriteError(w, http.StatusBadRequest, "retry_unavailable", "only failed, cancelled, or paused tasks can be retried")
+ return
+ }
+ resp, err := h.factory.CreateTask(cloneCreateTaskRequest(task.Request))
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "retry_failed", err.Error())
+ return
+ }
+ WriteJSON(w, http.StatusCreated, resp)
+}
+
+func (h *Handlers) UpdateTaskPathHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPut && r.Method != http.MethodPatch {
+ MethodNotAllowedHandler(w, r)
+ return
+ }
+ taskID, _ := extractTaskIDAndAction(r.URL.Path)
+ if taskID == "" {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
+ return
+ }
+ task, ok := GetTask(taskID)
+ if !ok {
+ WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
+ return
+ }
+ if task.Type != string(tasktype.TaskTypeTransfer) {
+ WriteError(w, http.StatusBadRequest, "unsupported_task", "only transfer task path can be updated")
+ return
+ }
+ var req UpdateTaskPathRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
+ return
+ }
+ req.Path = strings.TrimSpace(req.Path)
+ if req.Path == "" {
+ WriteError(w, http.StatusBadRequest, "invalid_request", "path is required")
+ return
+ }
+ control, ok := GetTaskControl(taskID)
+ if !ok {
+ WriteError(w, http.StatusConflict, "control_unavailable", "task runtime control is not available")
+ return
+ }
+ control.UpdateTargetPath(req.Path)
+ task.UpdateTargetPath(req.Path)
+ updateStoredTransferTargetPath(task, req.Path)
+ WriteJSON(w, http.StatusOK, map[string]any{
+ "message": "task path updated",
+ "target_path": req.Path,
+ })
+}
+
// ListStoragesHandler 列出存储处理器
func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
@@ -134,8 +244,9 @@ func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) {
return
}
- storages := make([]StorageInfo, 0, len(storage.Storages))
- for name, stor := range storage.Storages {
+ snapshot := storage.Snapshot()
+ storages := make([]StorageInfo, 0, len(snapshot))
+ for name, stor := range snapshot {
storages = append(storages, StorageInfo{
Name: name,
Type: string(stor.Type()),
@@ -177,33 +288,54 @@ func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// extractTaskIDFromPath 从路径中提取任务 ID
// 路径格式: /api/v1/tasks/:id
func extractTaskIDFromPath(path string) string {
+ taskID, _ := extractTaskIDAndAction(path)
+ return taskID
+}
+
+func extractTaskIDAndAction(path string) (string, string) {
parts := strings.Split(strings.Trim(path, "/"), "/")
if len(parts) < 4 {
- return ""
+ return "", ""
+ }
+ action := ""
+ if len(parts) > 4 {
+ action = parts[4]
}
- return parts[3]
+ return parts[3], action
}
// convertTaskProgressToResponse 将任务进度转换为响应格式
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
resp := TaskInfoResponse{
- TaskID: task.TaskID,
- Type: tasktype.TaskType(task.Type),
- Status: task.Status,
- Title: task.Title,
- Storage: task.Storage,
- Path: task.Path,
- Error: task.Error,
- CreatedAt: task.CreatedAt,
- UpdatedAt: task.UpdatedAt,
+ TaskID: task.TaskID,
+ Type: tasktype.TaskType(task.Type),
+ Status: task.Status,
+ Title: task.Title,
+ Storage: task.Storage,
+ Path: task.Path,
+ SourceStorage: task.SourceStorage,
+ SourcePath: task.SourcePath,
+ TargetStorage: task.TargetStorage,
+ TargetPath: task.TargetPath,
+ Phase: task.Phase,
+ Error: task.Error,
+ CreatedAt: task.CreatedAt,
+ UpdatedAt: task.UpdatedAt,
}
// 计算进度
if task.TotalBytes > 0 {
- percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
+ downloaded := atomic.LoadInt64(&task.DownloadedBytes)
+ uploaded := atomic.LoadInt64(&task.UploadedBytes)
+ current := downloaded
+ if uploaded > 0 || task.Phase == "uploading" {
+ current = uploaded
+ }
+ percent := float64(current) * 100 / float64(task.TotalBytes)
resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes,
- DownloadedBytes: task.DownloadedBytes,
+ DownloadedBytes: downloaded,
+ UploadedBytes: uploaded,
Percent: percent,
}
}
@@ -211,6 +343,22 @@ func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
return resp
}
+func updateStoredTransferTargetPath(info *TaskProgressInfo, targetPath string) {
+ if info.Request == nil || info.Request.Params == nil {
+ return
+ }
+ var params TransferParams
+ if err := json.Unmarshal(info.Request.Params, ¶ms); err != nil {
+ return
+ }
+ params.TargetPath = targetPath
+ data, err := json.Marshal(params)
+ if err != nil {
+ return
+ }
+ info.Request.Params = data
+}
+
// NotFoundHandler 404 处理器
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path)
diff --git a/api/progress.go b/api/progress.go
index adac3712..71d103a9 100644
--- a/api/progress.go
+++ b/api/progress.go
@@ -1,6 +1,7 @@
package api
import (
+ "encoding/json"
"sync"
"sync/atomic"
"time"
@@ -14,28 +15,41 @@ type TaskProgressInfo struct {
Title string
TotalBytes int64
DownloadedBytes int64
+ UploadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
+ SourceStorage string
+ SourcePath string
+ TargetStorage string
+ TargetPath string
+ Phase string
Error string
CreatedAt time.Time
UpdatedAt time.Time
Webhook string
+ Request *CreateTaskRequest
}
// progressStore 存储所有 API 任务的进度信息
type progressStore struct {
- mu sync.RWMutex
- tasks map[string]*TaskProgressInfo
+ mu sync.RWMutex
+ tasks map[string]*TaskProgressInfo
+ controls map[string]TaskControl
}
var store = &progressStore{
- tasks: make(map[string]*TaskProgressInfo),
+ tasks: make(map[string]*TaskProgressInfo),
+ controls: make(map[string]TaskControl),
+}
+
+type TaskControl interface {
+ UpdateTargetPath(path string)
}
// RegisterTask 注册一个新的 API 任务
-func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
+func RegisterTask(taskID, taskType, storage, path, title, webhook string, reqs ...*CreateTaskRequest) *TaskProgressInfo {
info := &TaskProgressInfo{
TaskID: taskID,
Type: taskType,
@@ -43,10 +57,14 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
Title: title,
Storage: storage,
Path: path,
+ Phase: "queued",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Webhook: webhook,
}
+ if len(reqs) > 0 && reqs[0] != nil {
+ info.Request = cloneCreateTaskRequest(reqs[0])
+ }
store.mu.Lock()
store.tasks[taskID] = info
@@ -80,6 +98,23 @@ func DeleteTask(taskID string) {
store.mu.Lock()
defer store.mu.Unlock()
delete(store.tasks, taskID)
+ delete(store.controls, taskID)
+}
+
+func RegisterTaskControl(taskID string, control TaskControl) {
+ if control == nil {
+ return
+ }
+ store.mu.Lock()
+ defer store.mu.Unlock()
+ store.controls[taskID] = control
+}
+
+func GetTaskControl(taskID string) (TaskControl, bool) {
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ control, ok := store.controls[taskID]
+ return control, ok
}
// UpdateStatus 更新任务状态
@@ -92,6 +127,44 @@ func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
func (t *TaskProgressInfo) SetError(err string) {
t.Error = err
t.Status = TaskStatusFailed
+ t.Phase = "failed"
+ t.UpdatedAt = time.Now()
+}
+
+func (t *TaskProgressInfo) SetTransferMeta(sourceStorage, sourcePath, targetStorage, targetPath string) {
+ t.SourceStorage = sourceStorage
+ t.SourcePath = sourcePath
+ t.TargetStorage = targetStorage
+ t.TargetPath = targetPath
+ t.Path = targetPath
+ t.UpdatedAt = time.Now()
+}
+
+func (t *TaskProgressInfo) UpdatePhase(phase string) {
+ t.Phase = phase
+ t.UpdatedAt = time.Now()
+}
+
+func (t *TaskProgressInfo) UpdateDownloadProgress(downloadedBytes, totalBytes int64) {
+ if totalBytes > 0 {
+ t.TotalBytes = totalBytes
+ }
+ atomic.StoreInt64(&t.DownloadedBytes, downloadedBytes)
+ t.UpdatedAt = time.Now()
+}
+
+func (t *TaskProgressInfo) UpdateUploadProgress(uploadedBytes, totalBytes int64) {
+ if totalBytes > 0 {
+ t.TotalBytes = totalBytes
+ }
+ atomic.StoreInt64(&t.UploadedBytes, uploadedBytes)
+ t.UpdatedAt = time.Now()
+}
+
+func (t *TaskProgressInfo) UpdateTargetPath(path string) {
+ t.TargetPath = path
+ t.Path = path
+ t.UploadedBytes = 0
t.UpdatedAt = time.Now()
}
@@ -111,6 +184,7 @@ func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
p.info.Status = TaskStatusRunning
p.info.TotalBytes = totalBytes
p.info.TotalFiles = totalFiles
+ p.info.Phase = "running"
p.info.UpdatedAt = time.Now()
}
@@ -126,8 +200,10 @@ func (p *ProgressTracker) OnDone(err error) {
if err != nil {
p.info.Status = TaskStatusFailed
p.info.Error = err.Error()
+ p.info.Phase = "failed"
} else {
p.info.Status = TaskStatusCompleted
+ p.info.Phase = "completed"
}
p.info.UpdatedAt = time.Now()
}
@@ -148,3 +224,14 @@ func (p *ProgressTracker) UpdateProgressFiles(files int) {
p.info.DownloadedFiles = files
p.info.UpdatedAt = time.Now()
}
+
+func cloneCreateTaskRequest(req *CreateTaskRequest) *CreateTaskRequest {
+ if req == nil {
+ return nil
+ }
+ cloned := *req
+ if req.Params != nil {
+ cloned.Params = append(json.RawMessage(nil), req.Params...)
+ }
+ return &cloned
+}
diff --git a/api/runtime.go b/api/runtime.go
new file mode 100644
index 00000000..8ce7c285
--- /dev/null
+++ b/api/runtime.go
@@ -0,0 +1,165 @@
+package api
+
+import (
+ "context"
+ "fmt"
+ "runtime"
+ "slices"
+ "strings"
+
+ "github.com/blang/semver"
+ "github.com/charmbracelet/log"
+ "github.com/krau/SaveAny-Bot/common/i18n"
+ "github.com/krau/SaveAny-Bot/config"
+ "github.com/krau/SaveAny-Bot/database"
+ "github.com/krau/SaveAny-Bot/storage"
+ "github.com/unvgo/ghselfupdate"
+)
+
+type RuntimeApplyResult struct {
+ Applied []string `json:"applied"`
+ RestartRequired []string `json:"restart_required"`
+ Warnings []string `json:"warnings,omitempty"`
+}
+
+type UpdateCheckResponse struct {
+ CurrentVersion string `json:"current_version"`
+ LatestVersion string `json:"latest_version,omitempty"`
+ LatestName string `json:"latest_name,omitempty"`
+ Found bool `json:"found"`
+ HasUpdate bool `json:"has_update"`
+ ReleaseURL string `json:"release_url,omitempty"`
+ ReleaseNotes string `json:"release_notes,omitempty"`
+ Platform string `json:"platform"`
+ Message string `json:"message,omitempty"`
+}
+
+func ApplyRuntimeConfig(ctx context.Context, configPath string, before config.Config) RuntimeApplyResult {
+ result := RuntimeApplyResult{
+ Applied: []string{"配置文件"},
+ }
+ after := config.C()
+
+ if err := applyLogLevel(after.Log.Level); err != nil {
+ result.Warnings = append(result.Warnings, err.Error())
+ } else {
+ result.Applied = append(result.Applied, "日志级别")
+ }
+
+ i18n.Init(after.Lang)
+ result.Applied = append(result.Applied, "语言")
+
+ if database.Ready() {
+ if dbPathChanged(before, after) {
+ result.RestartRequired = append(result.RestartRequired, "数据库路径")
+ } else if err := database.SyncUsers(ctx); err != nil {
+ result.Warnings = append(result.Warnings, "同步用户失败: "+err.Error())
+ } else {
+ result.Applied = append(result.Applied, "用户配置")
+ }
+ }
+
+ if err := storage.ReloadStorages(ctx); err != nil {
+ result.Warnings = append(result.Warnings, "部分存储加载失败: "+err.Error())
+ }
+ result.Applied = append(result.Applied, "存储配置")
+
+ result.Applied = append(result.Applied, "全局 HTTP 代理", "API Token")
+ result.RestartRequired = append(result.RestartRequired, restartRequiredChanges(before, after)...)
+ result.RestartRequired = compactStrings(result.RestartRequired)
+ result.Applied = compactStrings(result.Applied)
+ return result
+}
+
+func CheckUpdate() (*UpdateCheckResponse, error) {
+ resp := &UpdateCheckResponse{
+ CurrentVersion: config.Version,
+ Platform: runtime.GOOS + "/" + runtime.GOARCH,
+ }
+ current, err := semver.ParseTolerant(strings.TrimPrefix(config.Version, "v"))
+ if err != nil {
+ resp.Message = "当前版本不是发布版本,无法准确比较。"
+ }
+
+ latest, found, err := ghselfupdate.DetectLatest(config.GitRepo)
+ if err != nil {
+ return nil, err
+ }
+ resp.Found = found
+ if !found {
+ resp.Message = "没有找到 release。"
+ return resp, nil
+ }
+
+ resp.LatestVersion = latest.Version.String()
+ resp.LatestName = latest.Name
+ resp.ReleaseNotes = latest.ReleaseNotes
+ resp.ReleaseURL = latest.URL
+ if err == nil {
+ resp.HasUpdate = latest.Version.GT(current)
+ if latest.Version.Equals(current) || latest.Version.LT(current) {
+ resp.Message = "当前已经是最新版本。"
+ }
+ }
+ return resp, nil
+}
+
+func applyLogLevel(level string) error {
+ parsed, err := log.ParseLevel(strings.TrimSpace(level))
+ if err != nil {
+ return fmt.Errorf("日志级别无效: %w", err)
+ }
+ log.Default().SetLevel(parsed)
+ return nil
+}
+
+func restartRequiredChanges(before, after config.Config) []string {
+ var fields []string
+ if before.Lang == "" && before.DB.Path == "" {
+ return fields
+ }
+ if before.Workers != 0 && before.Workers != after.Workers {
+ fields = append(fields, "Workers 并发数")
+ }
+ if before.API.Enable != after.API.Enable || before.API.Host != after.API.Host || before.API.Port != after.API.Port {
+ fields = append(fields, "HTTP API 监听地址/端口")
+ }
+ if before.Telegram.Token != after.Telegram.Token ||
+ before.Telegram.AppID != after.Telegram.AppID ||
+ before.Telegram.AppHash != after.Telegram.AppHash ||
+ before.Telegram.Proxy != after.Telegram.Proxy ||
+ before.Telegram.Userbot != after.Telegram.Userbot {
+ fields = append(fields, "Telegram Bot/Userbot 连接")
+ }
+ if before.Parser.PluginEnable != after.Parser.PluginEnable || !slices.Equal(before.Parser.PluginDirs, after.Parser.PluginDirs) {
+ fields = append(fields, "Parser 插件加载")
+ }
+ if before.Cache != after.Cache {
+ fields = append(fields, "缓存参数")
+ }
+ if before.Hook != after.Hook {
+ fields = append(fields, "任务 Hook")
+ }
+ return fields
+}
+
+func dbPathChanged(before, after config.Config) bool {
+ return before.DB.Path != "" && before.DB.Path != after.DB.Path
+}
+
+func compactStrings(values []string) []string {
+ seen := make(map[string]struct{}, len(values))
+ out := make([]string, 0, len(values))
+ for _, value := range values {
+ value = strings.TrimSpace(value)
+ if value == "" {
+ continue
+ }
+ if _, ok := seen[value]; ok {
+ continue
+ }
+ seen[value] = struct{}{}
+ out = append(out, value)
+ }
+ return out
+}
diff --git a/api/server.go b/api/server.go
index be8b6603..1fd7e6fc 100644
--- a/api/server.go
+++ b/api/server.go
@@ -41,10 +41,26 @@ func NewServer(ctx context.Context) *Server {
}
})
mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) {
+ _, action := extractTaskIDAndAction(r.URL.Path)
+ switch action {
+ case "pause":
+ handlers.PauseTaskHandler(w, r)
+ return
+ case "retry":
+ handlers.RetryTaskHandler(w, r)
+ return
+ case "path":
+ handlers.UpdateTaskPathHandler(w, r)
+ return
+ }
// 根据方法和路径分发
switch r.Method {
case http.MethodGet:
- handlers.GetTaskHandler(w, r)
+ if r.URL.Path == "/api/v1/tasks" {
+ handlers.ListTasksHandler(w, r)
+ } else {
+ handlers.GetTaskHandler(w, r)
+ }
case http.MethodDelete:
handlers.CancelTaskHandler(w, r)
default:
@@ -53,6 +69,7 @@ func NewServer(ctx context.Context) *Server {
})
mux.HandleFunc("/api/v1/storages", handlers.ListStoragesHandler)
mux.HandleFunc("/api/v1/task-types", handlers.GetTaskTypesHandler)
+ RegisterConfigEditorRoutes(ctx, mux, config.ConfigFileUsed())
// 404 处理
mux.HandleFunc("/", NotFoundHandler)
diff --git a/api/task_progress_adapters.go b/api/task_progress_adapters.go
new file mode 100644
index 00000000..8f19ec66
--- /dev/null
+++ b/api/task_progress_adapters.go
@@ -0,0 +1,259 @@
+package api
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "strings"
+
+ "github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
+ "github.com/krau/SaveAny-Bot/core/tasks/batchtfile"
+ "github.com/krau/SaveAny-Bot/core/tasks/directlinks"
+ "github.com/krau/SaveAny-Bot/core/tasks/parsed"
+ tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph"
+ "github.com/krau/SaveAny-Bot/core/tasks/tfile"
+ "github.com/krau/SaveAny-Bot/core/tasks/transfer"
+ "github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
+ "github.com/krau/SaveAny-Bot/pkg/aria2"
+)
+
+type directLinksAPIProgress struct{ taskID string }
+
+func newDirectLinksAPIProgress(taskID string) directlinks.ProgressTracker {
+ return &directLinksAPIProgress{taskID: taskID}
+}
+
+func (p *directLinksAPIProgress) OnStart(ctx context.Context, info directlinks.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase("downloading")
+ task.TotalFiles = info.TotalFiles()
+ }
+}
+
+func (p *directLinksAPIProgress) OnProgress(ctx context.Context, info directlinks.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase("downloading")
+ task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes())
+ task.TotalFiles = info.TotalFiles()
+ }
+}
+
+func (p *directLinksAPIProgress) OnDone(ctx context.Context, info directlinks.TaskInfo, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type ytdlpAPIProgress struct{ taskID string }
+
+func newYTDLPAPIProgress(taskID string) ytdlp.ProgressTracker {
+ return &ytdlpAPIProgress{taskID: taskID}
+}
+
+func (p *ytdlpAPIProgress) OnStart(ctx context.Context, task *ytdlp.Task) {
+ if info, ok := GetTask(p.taskID); ok {
+ info.UpdateStatus(TaskStatusRunning)
+ info.UpdatePhase("downloading")
+ }
+}
+
+func (p *ytdlpAPIProgress) OnProgress(ctx context.Context, task *ytdlp.Task, status string) {
+ if info, ok := GetTask(p.taskID); ok {
+ if strings.HasPrefix(strings.ToLower(status), "transferred") {
+ info.UpdatePhase("uploading")
+ } else {
+ info.UpdatePhase("downloading")
+ }
+ }
+}
+
+func (p *ytdlpAPIProgress) OnDone(ctx context.Context, task *ytdlp.Task, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type aria2APIProgress struct{ taskID string }
+
+func newAria2APIProgress(taskID string) aria2dl.ProgressTracker {
+ return &aria2APIProgress{taskID: taskID}
+}
+
+func (p *aria2APIProgress) OnStart(ctx context.Context, task *aria2dl.Task) {
+ if info, ok := GetTask(p.taskID); ok {
+ info.UpdateStatus(TaskStatusRunning)
+ info.UpdatePhase("downloading")
+ }
+}
+
+func (p *aria2APIProgress) OnProgress(ctx context.Context, task *aria2dl.Task, status *aria2.Status) {
+ if info, ok := GetTask(p.taskID); ok {
+ total, _ := strconv.ParseInt(status.TotalLength, 10, 64)
+ completed, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
+ info.UpdatePhase("downloading")
+ info.UpdateDownloadProgress(completed, total)
+ }
+}
+
+func (p *aria2APIProgress) OnDone(ctx context.Context, task *aria2dl.Task, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type parsedAPIProgress struct{ taskID string }
+
+func newParsedAPIProgress(taskID string) parsed.ProgressTracker {
+ return &parsedAPIProgress{taskID: taskID}
+}
+
+func (p *parsedAPIProgress) OnStart(ctx context.Context, info parsed.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase("downloading")
+ task.TotalFiles = int(info.TotalResources())
+ task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes())
+ }
+}
+
+func (p *parsedAPIProgress) OnProgress(ctx context.Context, info parsed.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase("downloading")
+ task.TotalFiles = int(info.TotalResources())
+ task.DownloadedFiles = int(info.Downloaded())
+ task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes())
+ }
+}
+
+func (p *parsedAPIProgress) OnDone(ctx context.Context, info parsed.TaskInfo, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type tfileAPIProgress struct{ taskID string }
+
+func newTFileAPIProgress(taskID string) tfile.ProgressTracker {
+ return &tfileAPIProgress{taskID: taskID}
+}
+
+func (p *tfileAPIProgress) OnStart(ctx context.Context, info tfile.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase("downloading")
+ task.TotalFiles = 1
+ task.UpdateDownloadProgress(0, info.FileSize())
+ }
+}
+
+func (p *tfileAPIProgress) OnProgress(ctx context.Context, info tfile.TaskInfo, downloaded, total int64) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase("downloading")
+ task.UpdateDownloadProgress(downloaded, total)
+ }
+}
+
+func (p *tfileAPIProgress) OnDone(ctx context.Context, info tfile.TaskInfo, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type batchTFileAPIProgress struct{ taskID string }
+
+func newBatchTFileAPIProgress(taskID string) batchtfile.ProgressTracker {
+ return &batchTFileAPIProgress{taskID: taskID}
+}
+
+func (p *batchTFileAPIProgress) OnStart(ctx context.Context, info batchtfile.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase("downloading")
+ task.TotalFiles = info.Count()
+ task.UpdateDownloadProgress(0, info.TotalSize())
+ }
+}
+
+func (p *batchTFileAPIProgress) OnProgress(ctx context.Context, info batchtfile.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase("downloading")
+ task.TotalFiles = info.Count()
+ task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize())
+ }
+}
+
+func (p *batchTFileAPIProgress) OnDone(ctx context.Context, info batchtfile.TaskInfo, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type telegraphAPIProgress struct{ taskID string }
+
+func newTelegraphAPIProgress(taskID string) tphtask.ProgressTracker {
+ return &telegraphAPIProgress{taskID: taskID}
+}
+
+func (p *telegraphAPIProgress) OnStart(ctx context.Context, info tphtask.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase("downloading")
+ task.TotalFiles = info.TotalPics()
+ }
+}
+
+func (p *telegraphAPIProgress) OnProgress(ctx context.Context, info tphtask.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase("downloading")
+ task.DownloadedFiles = int(info.Downloaded())
+ }
+}
+
+func (p *telegraphAPIProgress) OnDone(ctx context.Context, info tphtask.TaskInfo, err error) {
+ finishAPITask(p.taskID, err)
+}
+
+type transferAPIProgress struct{ taskID string }
+
+func newTransferAPIProgress(taskID string) transfer.ProgressTracker {
+ return &transferAPIProgress{taskID: taskID}
+}
+
+func (p *transferAPIProgress) OnStart(ctx context.Context, info transfer.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateStatus(TaskStatusRunning)
+ task.UpdatePhase(info.Phase())
+ task.TotalFiles = info.Count()
+ task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize())
+ task.UpdateUploadProgress(info.Uploaded(), info.TotalSize())
+ task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath())
+ }
+}
+
+func (p *transferAPIProgress) OnProgress(ctx context.Context, info transfer.TaskInfo) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdatePhase(info.Phase())
+ task.TotalFiles = info.Count()
+ task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize())
+ task.UpdateUploadProgress(info.Uploaded(), info.TotalSize())
+ task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath())
+ }
+}
+
+func (p *transferAPIProgress) OnDone(ctx context.Context, info transfer.TaskInfo, err error) {
+ if task, ok := GetTask(p.taskID); ok {
+ task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize())
+ task.UpdateUploadProgress(info.Uploaded(), info.TotalSize())
+ task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath())
+ }
+ finishAPITask(p.taskID, err)
+}
+
+func finishAPITask(taskID string, err error) {
+ info, ok := GetTask(taskID)
+ if !ok {
+ return
+ }
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ if info.Status != TaskStatusPaused {
+ info.UpdateStatus(TaskStatusCancelled)
+ info.UpdatePhase("cancelled")
+ }
+ return
+ }
+ info.SetError(err.Error())
+ return
+ }
+ info.UpdateStatus(TaskStatusCompleted)
+ info.UpdatePhase("completed")
+}
diff --git a/api/types.go b/api/types.go
index 5462f6c6..cea532b6 100644
--- a/api/types.go
+++ b/api/types.go
@@ -14,6 +14,7 @@ type TaskStatus string
const (
TaskStatusQueued TaskStatus = "queued"
TaskStatusRunning TaskStatus = "running"
+ TaskStatusPaused TaskStatus = "paused"
TaskStatusCompleted TaskStatus = "completed"
TaskStatusFailed TaskStatus = "failed"
TaskStatusCancelled TaskStatus = "cancelled"
@@ -40,22 +41,28 @@ type CreateTaskResponse struct {
type TaskProgress struct {
TotalBytes int64 `json:"total_bytes,omitempty"`
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
+ UploadedBytes int64 `json:"uploaded_bytes,omitempty"`
Percent float64 `json:"percent,omitempty"`
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
}
// TaskInfoResponse 任务信息响应
type TaskInfoResponse struct {
- TaskID string `json:"task_id"`
- Type tasktype.TaskType `json:"type"`
- Status TaskStatus `json:"status"`
- Title string `json:"title"`
- Progress *TaskProgress `json:"progress,omitempty"`
- Storage string `json:"storage"`
- Path string `json:"path"`
- Error string `json:"error,omitempty"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ TaskID string `json:"task_id"`
+ Type tasktype.TaskType `json:"type"`
+ Status TaskStatus `json:"status"`
+ Title string `json:"title"`
+ Progress *TaskProgress `json:"progress,omitempty"`
+ Storage string `json:"storage"`
+ Path string `json:"path"`
+ SourceStorage string `json:"source_storage,omitempty"`
+ SourcePath string `json:"source_path,omitempty"`
+ TargetStorage string `json:"target_storage,omitempty"`
+ TargetPath string `json:"target_path,omitempty"`
+ Phase string `json:"phase,omitempty"`
+ Error string `json:"error,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
}
// TasksListResponse 任务列表响应
@@ -150,6 +157,22 @@ type TransferParams struct {
TargetPath string `json:"target_path"`
}
+type UpdateTaskPathRequest struct {
+ Path string `json:"path"`
+}
+
+type ProxyTestRequest struct {
+ URL string `json:"url"`
+ Target string `json:"target,omitempty"`
+}
+
+type ProxyTestResponse struct {
+ OK bool `json:"ok"`
+ MS int64 `json:"ms"`
+ Message string `json:"message,omitempty"`
+ Target string `json:"target"`
+}
+
// TGFilesParams tgfiles 任务参数
type TGFilesParams struct {
MessageLinks []string `json:"message_links"`
diff --git a/cmd/run.go b/cmd/run.go
index b1390a2d..00e96a2b 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -2,6 +2,7 @@ package cmd
import (
"context"
+ "io"
"os"
"path/filepath"
"strings"
@@ -15,6 +16,7 @@ import (
userclient "github.com/krau/SaveAny-Bot/client/user"
"github.com/krau/SaveAny-Bot/common/cache"
"github.com/krau/SaveAny-Bot/common/i18n"
+ "github.com/krau/SaveAny-Bot/common/logbuffer"
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
@@ -26,7 +28,7 @@ import (
func Run(cmd *cobra.Command, _ []string) {
ctx, cancel := context.WithCancel(cmd.Context())
- logger := log.NewWithOptions(os.Stdout, log.Options{
+ logger := log.NewWithOptions(io.MultiWriter(os.Stdout, logbuffer.Default()), log.Options{
Level: log.InfoLevel,
ReportTimestamp: true,
TimeFormat: time.TimeOnly,
@@ -51,10 +53,12 @@ func Run(cmd *cobra.Command, _ []string) {
if err != nil {
logger.Fatal("Init failed", "error", err)
}
- go func() {
- <-exitChan
- cancel()
- }()
+ if exitChan != nil {
+ go func() {
+ <-exitChan
+ cancel()
+ }()
+ }
core.Run(ctx)
@@ -89,6 +93,10 @@ func initAll(ctx context.Context) (<-chan struct{}, error) {
if err := api.Start(ctx); err != nil {
logger.Error("Failed to start API server", "error", err)
}
+ if strings.TrimSpace(config.C().Telegram.Token) == "" {
+ logger.Warn("Telegram bot token is empty; skip bot initialization and keep the config web/API server running")
+ return nil, nil
+ }
return bot.Init(ctx), nil
}
diff --git a/cmd/web.go b/cmd/web.go
new file mode 100644
index 00000000..54e5454d
--- /dev/null
+++ b/cmd/web.go
@@ -0,0 +1,67 @@
+package cmd
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/krau/SaveAny-Bot/api"
+ "github.com/krau/SaveAny-Bot/common/logbuffer"
+ "github.com/krau/SaveAny-Bot/config"
+ "github.com/krau/SaveAny-Bot/database"
+ "github.com/spf13/cobra"
+)
+
+var webCmd = &cobra.Command{
+ Use: "web",
+ Short: "Start the visual configuration web UI",
+ Run: runWeb,
+}
+
+func init() {
+ webCmd.Flags().StringP("config", "c", config.DefaultConfigFile, "config file path")
+ webCmd.Flags().String("host", "0.0.0.0", "web UI listen host")
+ webCmd.Flags().Int("port", config.DefaultAPIPort, "web UI listen port")
+ webCmd.Flags().String("token", "", "web UI API token")
+ rootCmd.AddCommand(webCmd)
+}
+
+func runWeb(cmd *cobra.Command, _ []string) {
+ ctx := cmd.Context()
+ logger := log.NewWithOptions(io.MultiWriter(os.Stdout, logbuffer.Default()), log.Options{
+ Level: log.InfoLevel,
+ ReportTimestamp: true,
+ TimeFormat: time.TimeOnly,
+ ReportCaller: true,
+ })
+ log.SetDefault(logger)
+ ctx = log.WithContext(ctx, logger)
+
+ configPath, _ := cmd.Flags().GetString("config")
+ host, _ := cmd.Flags().GetString("host")
+ port, _ := cmd.Flags().GetInt("port")
+ token, _ := cmd.Flags().GetString("token")
+
+ if _, err := os.Stat(configPath); err == nil {
+ if err := config.Init(ctx, configPath); err != nil {
+ logger.Warn("Config file exists but could not be loaded; web editor will still start", "error", err)
+ } else if err := database.Open(ctx); err != nil {
+ logger.Warn("Database could not be opened; rule editor will become available after a valid config is saved", "error", err)
+ }
+ }
+ if token == "" && host != "127.0.0.1" && host != "localhost" {
+ logger.Warn("Config web UI is listening without a token", "host", host)
+ }
+ if _, err := api.StartConfigWebServer(ctx, api.ConfigWebServerOptions{
+ ConfigPath: configPath,
+ Host: host,
+ Port: port,
+ Token: token,
+ }); err != nil {
+ logger.Fatal("Failed to start config web server", "error", err)
+ }
+ fmt.Printf("Config web UI: http://%s:%d/config\n", host, port)
+ <-ctx.Done()
+}
diff --git a/common/logbuffer/logbuffer.go b/common/logbuffer/logbuffer.go
new file mode 100644
index 00000000..1dd38435
--- /dev/null
+++ b/common/logbuffer/logbuffer.go
@@ -0,0 +1,79 @@
+package logbuffer
+
+import (
+ "bytes"
+ "regexp"
+ "strings"
+ "sync"
+)
+
+const defaultLimit = 600
+
+var ansiPattern = regexp.MustCompile(`\x1b\[[0-9;]*[A-Za-z]`)
+
+type Buffer struct {
+ mu sync.RWMutex
+ limit int
+ lines []string
+ partial string
+}
+
+var defaultBuffer = New(defaultLimit)
+
+func Default() *Buffer {
+ return defaultBuffer
+}
+
+func New(limit int) *Buffer {
+ if limit < 1 {
+ limit = defaultLimit
+ }
+ return &Buffer{limit: limit}
+}
+
+func (b *Buffer) Write(p []byte) (int, error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ text := b.partial + string(p)
+ parts := strings.Split(text, "\n")
+ b.partial = parts[len(parts)-1]
+ for _, line := range parts[:len(parts)-1] {
+ b.appendLine(line)
+ }
+ return len(p), nil
+}
+
+func (b *Buffer) Lines(limit int) []string {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ lines := b.lines
+ if b.partial != "" {
+ lines = append(append([]string{}, b.lines...), cleanLine(b.partial))
+ }
+ if limit < 1 || limit > len(lines) {
+ limit = len(lines)
+ }
+ out := make([]string, limit)
+ copy(out, lines[len(lines)-limit:])
+ return out
+}
+
+func (b *Buffer) appendLine(line string) {
+ line = cleanLine(line)
+ if strings.TrimSpace(line) == "" {
+ return
+ }
+ b.lines = append(b.lines, line)
+ if len(b.lines) > b.limit {
+ copy(b.lines, b.lines[len(b.lines)-b.limit:])
+ b.lines = b.lines[:b.limit]
+ }
+}
+
+func cleanLine(line string) string {
+ line = ansiPattern.ReplaceAllString(line, "")
+ line = strings.TrimRight(line, "\r")
+ return string(bytes.TrimRight([]byte(line), "\x00"))
+}
diff --git a/config/edit.go b/config/edit.go
new file mode 100644
index 00000000..2af45049
--- /dev/null
+++ b/config/edit.go
@@ -0,0 +1,743 @@
+package config
+
+import (
+ "fmt"
+ "net/url"
+ "os"
+ "path/filepath"
+ "regexp"
+ "slices"
+ "strconv"
+ "strings"
+
+ storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
+ "github.com/pelletier/go-toml/v2"
+ "github.com/spf13/viper"
+)
+
+const (
+ DefaultConfigFile = "config.toml"
+ DefaultAPIPort = 19191
+)
+
+type EditableConfig struct {
+ Lang string `toml:"lang" json:"lang"`
+ Workers int `toml:"workers" json:"workers"`
+ Retry int `toml:"retry" json:"retry"`
+ Threads int `toml:"threads" json:"threads"`
+ Stream bool `toml:"stream" json:"stream"`
+ NoCleanCache bool `toml:"no_clean_cache" json:"no_clean_cache"`
+ Proxy string `toml:"proxy" json:"proxy"`
+ Log EditableLogConfig `toml:"log" json:"log"`
+ Telegram EditableTelegramConfig `toml:"telegram" json:"telegram"`
+ Aria2 EditableAria2Config `toml:"aria2" json:"aria2"`
+ API EditableAPIConfig `toml:"api" json:"api"`
+ Cache EditableCacheConfig `toml:"cache" json:"cache"`
+ Temp EditableTempConfig `toml:"temp" json:"temp"`
+ DB EditableDBConfig `toml:"db" json:"db"`
+ Parser EditableParserConfig `toml:"parser" json:"parser"`
+ Hook EditableHookConfig `toml:"hook" json:"hook"`
+ Storages []map[string]any `toml:"storages" json:"storages"`
+ Users []EditableUserConfig `toml:"users" json:"users"`
+}
+
+type EditableLogConfig struct {
+ Level string `toml:"level" json:"level"`
+}
+
+type EditableTelegramConfig struct {
+ Token string `toml:"token" json:"token"`
+ AppID int `toml:"app_id" json:"app_id"`
+ AppHash string `toml:"app_hash" json:"app_hash"`
+ Proxy EditableTelegramProxyConfig `toml:"proxy" json:"proxy"`
+ RpcRetry int `toml:"rpc_retry" json:"rpc_retry"`
+ Userbot EditableUserbotConfig `toml:"userbot" json:"userbot"`
+ MediaGroupTimeout int `toml:"media_group_timeout" json:"media_group_timeout"`
+}
+
+type EditableTelegramProxyConfig struct {
+ Enable bool `toml:"enable" json:"enable"`
+ URL string `toml:"url" json:"url"`
+}
+
+type EditableUserbotConfig struct {
+ Enable bool `toml:"enable" json:"enable"`
+ Session string `toml:"session" json:"session"`
+}
+
+type EditableAria2Config struct {
+ Enable bool `toml:"enable" json:"enable"`
+ Url string `toml:"url" json:"url"`
+ Secret string `toml:"secret" json:"secret"`
+ KeepFile bool `toml:"keep_file" json:"keep_file"`
+}
+
+type EditableAPIConfig struct {
+ Enable bool `toml:"enable" json:"enable"`
+ Host string `toml:"host" json:"host"`
+ Port int `toml:"port" json:"port"`
+ Token string `toml:"token" json:"token"`
+}
+
+type EditableCacheConfig struct {
+ TTL int64 `toml:"ttl" json:"ttl"`
+ NumCounters int64 `toml:"num_counters" json:"num_counters"`
+ MaxCost int64 `toml:"max_cost" json:"max_cost"`
+}
+
+type EditableTempConfig struct {
+ BasePath string `toml:"base_path" json:"base_path"`
+}
+
+type EditableDBConfig struct {
+ Path string `toml:"path" json:"path"`
+ Session string `toml:"session" json:"session"`
+}
+
+type EditableParserConfig struct {
+ PluginEnable bool `toml:"plugin_enable" json:"plugin_enable"`
+ PluginDirs []string `toml:"plugin_dirs" json:"plugin_dirs"`
+ Proxy string `toml:"proxy" json:"proxy"`
+ ParserCfgs map[string]map[string]any `toml:",inline" json:"parser_cfgs,omitempty"`
+}
+
+type EditableHookConfig struct {
+ Exec EditableHookExecConfig `toml:"exec" json:"exec"`
+}
+
+type EditableHookExecConfig struct {
+ TaskBeforeStart string `toml:"task_before_start" json:"task_before_start"`
+ TaskSuccess string `toml:"task_success" json:"task_success"`
+ TaskFail string `toml:"task_fail" json:"task_fail"`
+ TaskCancel string `toml:"task_cancel" json:"task_cancel"`
+}
+
+type EditableUserConfig struct {
+ ID int64 `toml:"id" json:"id"`
+ Storages []string `toml:"storages" json:"storages"`
+ Blacklist bool `toml:"blacklist" json:"blacklist"`
+}
+
+type EditableConfigFile struct {
+ Path string `json:"path"`
+ Exists bool `json:"exists"`
+ Config EditableConfig `json:"config"`
+}
+
+type StorageFieldSchema struct {
+ Name string `json:"name"`
+ Label string `json:"label"`
+ Type string `json:"type"`
+ Required bool `json:"required"`
+ Secret bool `json:"secret,omitempty"`
+ Placeholder string `json:"placeholder,omitempty"`
+ Help string `json:"help,omitempty"`
+ Options []string `json:"options,omitempty"`
+}
+
+type StorageTypeSchema struct {
+ Type string `json:"type"`
+ Label string `json:"label"`
+ Fields []StorageFieldSchema `json:"fields"`
+}
+
+func ResolveConfigFilePath(path string) string {
+ if path != "" {
+ return path
+ }
+ if used := viper.ConfigFileUsed(); used != "" {
+ return used
+ }
+ return DefaultConfigFile
+}
+
+func ConfigFileUsed() string {
+ return viper.ConfigFileUsed()
+}
+
+func DefaultEditableConfig() EditableConfig {
+ return EditableConfig{
+ Lang: "zh-Hans",
+ Workers: 4,
+ Retry: 3,
+ Threads: 4,
+ Log: EditableLogConfig{
+ Level: "debug",
+ },
+ Telegram: EditableTelegramConfig{
+ AppID: 1025907,
+ AppHash: "452b0359b988148995f22ff0f4229750",
+ RpcRetry: 5,
+ Userbot: EditableUserbotConfig{
+ Session: "data/usersession.db",
+ },
+ },
+ Aria2: EditableAria2Config{
+ Url: "http://localhost:6800/jsonrpc",
+ },
+ API: EditableAPIConfig{
+ Host: "0.0.0.0",
+ Port: DefaultAPIPort,
+ },
+ Cache: EditableCacheConfig{
+ TTL: 86400,
+ NumCounters: 100000,
+ MaxCost: 1000000,
+ },
+ Temp: EditableTempConfig{
+ BasePath: "cache/",
+ },
+ DB: EditableDBConfig{
+ Path: "data/saveany.db",
+ Session: "data/session.db",
+ },
+ Parser: EditableParserConfig{
+ PluginDirs: []string{"plugins"},
+ },
+ Storages: []map[string]any{
+ {
+ "name": "local",
+ "type": storenum.Local.String(),
+ "enable": true,
+ "base_path": "./downloads",
+ },
+ },
+ Users: []EditableUserConfig{},
+ }
+}
+
+func LoadEditableConfig(path string) (*EditableConfigFile, error) {
+ path = ResolveConfigFilePath(path)
+ cfg := DefaultEditableConfig()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return &EditableConfigFile{Path: path, Exists: false, Config: cfg}, nil
+ }
+ return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
+ }
+ if strings.TrimSpace(string(data)) == "" {
+ return &EditableConfigFile{Path: path, Exists: true, Config: cfg}, nil
+ }
+ if err := toml.Unmarshal(data, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to parse config file %s: %w", path, err)
+ }
+ NormalizeEditableConfig(&cfg)
+ return &EditableConfigFile{Path: path, Exists: true, Config: cfg}, nil
+}
+
+func SaveEditableConfig(path string, cfg *EditableConfig) error {
+ path = ResolveConfigFilePath(path)
+ if isRemoteConfigPath(path) {
+ return fmt.Errorf("remote config files are read-only in the web editor")
+ }
+ NormalizeEditableConfig(cfg)
+ if err := ValidateEditableConfig(*cfg); err != nil {
+ return err
+ }
+ data, err := toml.Marshal(cfg)
+ if err != nil {
+ return fmt.Errorf("failed to encode config: %w", err)
+ }
+ if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil && filepath.Dir(path) != "." {
+ return fmt.Errorf("failed to create config directory: %w", err)
+ }
+ mode := os.FileMode(0644)
+ if st, err := os.Stat(path); err == nil {
+ mode = st.Mode().Perm()
+ if old, err := os.ReadFile(path); err == nil {
+ _ = os.WriteFile(path+".bak", old, mode)
+ }
+ }
+ tmp := path + ".tmp"
+ if err := os.WriteFile(tmp, data, mode); err != nil {
+ return fmt.Errorf("failed to write temporary config file: %w", err)
+ }
+ if err := os.Rename(tmp, path); err != nil {
+ _ = os.Remove(tmp)
+ return fmt.Errorf("failed to replace config file: %w", err)
+ }
+ return nil
+}
+
+func NormalizeEditableConfig(cfg *EditableConfig) {
+ defaults := DefaultEditableConfig()
+ if cfg.Lang == "" {
+ cfg.Lang = defaults.Lang
+ }
+ if cfg.Workers < 1 {
+ cfg.Workers = defaults.Workers
+ }
+ if cfg.Retry < 1 {
+ cfg.Retry = defaults.Retry
+ }
+ if cfg.Threads < 1 {
+ cfg.Threads = defaults.Threads
+ }
+ if cfg.Log.Level == "" {
+ cfg.Log.Level = defaults.Log.Level
+ }
+ if cfg.Telegram.AppID == 0 {
+ cfg.Telegram.AppID = defaults.Telegram.AppID
+ }
+ if cfg.Telegram.AppHash == "" {
+ cfg.Telegram.AppHash = defaults.Telegram.AppHash
+ }
+ if cfg.Telegram.RpcRetry == 0 {
+ cfg.Telegram.RpcRetry = defaults.Telegram.RpcRetry
+ }
+ if cfg.Telegram.Userbot.Session == "" {
+ cfg.Telegram.Userbot.Session = defaults.Telegram.Userbot.Session
+ }
+ if cfg.Aria2.Url == "" {
+ cfg.Aria2.Url = defaults.Aria2.Url
+ }
+ if cfg.API.Host == "" {
+ cfg.API.Host = defaults.API.Host
+ }
+ if cfg.API.Port == 0 {
+ cfg.API.Port = defaults.API.Port
+ }
+ if cfg.Cache.TTL == 0 {
+ cfg.Cache.TTL = defaults.Cache.TTL
+ }
+ if cfg.Cache.NumCounters == 0 {
+ cfg.Cache.NumCounters = defaults.Cache.NumCounters
+ }
+ if cfg.Cache.MaxCost == 0 {
+ cfg.Cache.MaxCost = defaults.Cache.MaxCost
+ }
+ if cfg.Temp.BasePath == "" {
+ cfg.Temp.BasePath = defaults.Temp.BasePath
+ }
+ if cfg.DB.Path == "" {
+ cfg.DB.Path = defaults.DB.Path
+ }
+ if cfg.DB.Session == "" {
+ cfg.DB.Session = defaults.DB.Session
+ }
+ if cfg.Parser.PluginDirs == nil {
+ cfg.Parser.PluginDirs = defaults.Parser.PluginDirs
+ }
+ if cfg.Storages == nil {
+ cfg.Storages = []map[string]any{}
+ }
+ if cfg.Users == nil {
+ cfg.Users = []EditableUserConfig{}
+ }
+ for i := range cfg.Storages {
+ normalizeStorageMap(cfg.Storages[i])
+ fillDefaultStoragePath(cfg.Storages[i])
+ }
+ storageNames := editableStorageNameSet(cfg.Storages)
+ for i := range cfg.Users {
+ if cfg.Users[i].Storages == nil {
+ cfg.Users[i].Storages = []string{}
+ }
+ cfg.Users[i].Storages = filterKnownStorageNames(cfg.Users[i].Storages, storageNames)
+ }
+}
+
+func ValidateEditableConfig(cfg EditableConfig) error {
+ if cfg.Workers < 1 {
+ return fmt.Errorf("workers must be greater than 0")
+ }
+ if cfg.Retry < 1 {
+ return fmt.Errorf("retry must be greater than 0")
+ }
+ if cfg.Threads < 1 {
+ return fmt.Errorf("threads must be greater than 0")
+ }
+ if !slices.Contains([]string{"trace", "debug", "info", "warn", "error", "fatal"}, strings.ToLower(cfg.Log.Level)) {
+ return fmt.Errorf("invalid log level: %s", cfg.Log.Level)
+ }
+ if err := validateProxyURL("proxy", cfg.Proxy, false); err != nil {
+ return err
+ }
+ if err := validateProxyURL("telegram.proxy.url", cfg.Telegram.Proxy.URL, cfg.Telegram.Proxy.Enable); err != nil {
+ return err
+ }
+ if err := validateProxyURL("parser.proxy", cfg.Parser.Proxy, false); err != nil {
+ return err
+ }
+ if cfg.API.Port < 1 || cfg.API.Port > 65535 {
+ return fmt.Errorf("api.port must be between 1 and 65535")
+ }
+ if cfg.Cache.TTL < 0 || cfg.Cache.NumCounters < 0 || cfg.Cache.MaxCost < 0 {
+ return fmt.Errorf("cache values must not be negative")
+ }
+ if cfg.DB.Path == "" {
+ return fmt.Errorf("db.path is required")
+ }
+ if cfg.DB.Session == "" {
+ return fmt.Errorf("db.session is required")
+ }
+ storageNames := make(map[string]struct{}, len(cfg.Storages))
+ for i, storage := range cfg.Storages {
+ if err := validateEditableStorage(i, storage); err != nil {
+ return err
+ }
+ name := getStringValue(storage, "name")
+ if _, ok := storageNames[name]; ok {
+ return fmt.Errorf("duplicate storage name: %s", name)
+ }
+ storageNames[name] = struct{}{}
+ }
+ userIDs := make(map[int64]struct{}, len(cfg.Users))
+ for _, user := range cfg.Users {
+ if user.ID == 0 {
+ return fmt.Errorf("user id is required")
+ }
+ if _, ok := userIDs[user.ID]; ok {
+ return fmt.Errorf("duplicate user id: %d", user.ID)
+ }
+ userIDs[user.ID] = struct{}{}
+ for _, storageName := range user.Storages {
+ if _, ok := storageNames[storageName]; !ok {
+ return fmt.Errorf("user %d references unknown storage %s", user.ID, storageName)
+ }
+ }
+ }
+ return nil
+}
+
+func StorageSchemas() []StorageTypeSchema {
+ return []StorageTypeSchema{
+ {
+ Type: storenum.Local.String(),
+ Label: "Local",
+ Fields: []StorageFieldSchema{
+ {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "./downloads"},
+ },
+ },
+ {
+ Type: storenum.Webdav.String(),
+ Label: "WebDAV",
+ Fields: []StorageFieldSchema{
+ {Name: "url", Label: "URL", Type: "url", Required: true, Placeholder: "https://example.com/dav"},
+ {Name: "username", Label: "用户名", Type: "string", Required: true},
+ {Name: "password", Label: "密码", Type: "password", Required: true, Secret: true},
+ {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "/telegram"},
+ },
+ },
+ {
+ Type: storenum.Alist.String(),
+ Label: "AList",
+ Fields: []StorageFieldSchema{
+ {Name: "url", Label: "URL", Type: "url", Required: true, Placeholder: "https://alist.example.com"},
+ {Name: "username", Label: "用户名", Type: "string"},
+ {Name: "password", Label: "密码", Type: "password", Secret: true},
+ {Name: "token", Label: "Token", Type: "password", Secret: true, Help: "Token 与用户名密码二选一"},
+ {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "/telegram"},
+ {Name: "token_exp", Label: "Token 过期时间", Type: "int"},
+ },
+ },
+ {
+ Type: storenum.Minio.String(),
+ Label: "MinIO",
+ Fields: objectStorageFields(false),
+ },
+ {
+ Type: storenum.S3.String(),
+ Label: "S3",
+ Fields: objectStorageFields(true),
+ },
+ {
+ Type: storenum.Telegram.String(),
+ Label: "Telegram",
+ Fields: []StorageFieldSchema{
+ {Name: "chat_id", Label: "Chat ID", Type: "int", Required: true, Placeholder: "-1001234567890"},
+ {Name: "force_file", Label: "强制文件模式", Type: "bool"},
+ {Name: "rate_limit", Label: "速率限制", Type: "int"},
+ {Name: "rate_burst", Label: "突发限制", Type: "int"},
+ {Name: "skip_large", Label: "跳过超大文件", Type: "bool"},
+ {Name: "split_size_mb", Label: "分卷大小 MB", Type: "int"},
+ },
+ },
+ {
+ Type: storenum.Rclone.String(),
+ Label: "Rclone",
+ Fields: []StorageFieldSchema{
+ {Name: "remote", Label: "Remote", Type: "string", Required: true, Placeholder: "remote:"},
+ {Name: "base_path", Label: "根路径", Type: "string", Placeholder: "/telegram"},
+ {Name: "config_path", Label: "配置文件路径", Type: "string"},
+ {Name: "flags", Label: "额外参数", Type: "string_list", Placeholder: "--transfers=4"},
+ },
+ },
+ }
+}
+
+func StorageTypeNames() []string {
+ return storenum.StorageTypeNames()
+}
+
+func objectStorageFields(includeS3Only bool) []StorageFieldSchema {
+ fields := []StorageFieldSchema{
+ {Name: "endpoint", Label: "Endpoint", Type: "string", Required: true, Placeholder: "s3.amazonaws.com"},
+ {Name: "access_key_id", Label: "Access Key ID", Type: "string", Required: true},
+ {Name: "secret_access_key", Label: "Secret Access Key", Type: "password", Required: true, Secret: true},
+ {Name: "bucket_name", Label: "Bucket", Type: "string", Required: true},
+ {Name: "use_ssl", Label: "Use SSL", Type: "bool"},
+ {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "telegram"},
+ }
+ if includeS3Only {
+ fields = append(fields,
+ StorageFieldSchema{Name: "region", Label: "Region", Type: "string", Placeholder: "ap-east-1"},
+ StorageFieldSchema{Name: "virtual_host", Label: "Virtual Host", Type: "bool"},
+ )
+ }
+ return fields
+}
+
+func validateEditableStorage(index int, storage map[string]any) error {
+ name := getStringValue(storage, "name")
+ if name == "" {
+ return fmt.Errorf("storages[%d].name is required", index)
+ }
+ storageType := getStringValue(storage, "type")
+ if storageType == "" {
+ return fmt.Errorf("storage %s type is required", name)
+ }
+ parsedType, err := storenum.ParseStorageType(storageType)
+ if err != nil {
+ return fmt.Errorf("invalid storage type %s for %s: %w", storageType, name, err)
+ }
+ if !getBoolValue(storage, "enable") {
+ return nil
+ }
+ required := func(key string) error {
+ if strings.TrimSpace(getStringValue(storage, key)) == "" {
+ return fmt.Errorf("%s is required for %s storage %s", key, parsedType, name)
+ }
+ return nil
+ }
+ switch parsedType {
+ case storenum.Local:
+ return required("base_path")
+ case storenum.Webdav:
+ for _, key := range []string{"url", "username", "password", "base_path"} {
+ if err := required(key); err != nil {
+ return err
+ }
+ }
+ case storenum.Alist:
+ if err := required("url"); err != nil {
+ return err
+ }
+ if getStringValue(storage, "token") == "" && (getStringValue(storage, "username") == "" || getStringValue(storage, "password") == "") {
+ return fmt.Errorf("username and password or token is required for alist storage %s", name)
+ }
+ return required("base_path")
+ case storenum.Minio, storenum.S3:
+ for _, key := range []string{"endpoint", "access_key_id", "secret_access_key", "bucket_name", "base_path"} {
+ if err := required(key); err != nil {
+ return err
+ }
+ }
+ case storenum.Telegram:
+ if getInt64Value(storage, "chat_id") == 0 {
+ return fmt.Errorf("chat_id is required for telegram storage %s", name)
+ }
+ if getInt64Value(storage, "rate_limit") < 0 || getInt64Value(storage, "rate_burst") < 0 {
+ return fmt.Errorf("rate_limit and rate_burst must not be negative for telegram storage %s", name)
+ }
+ case storenum.Rclone:
+ return required("remote")
+ }
+ return nil
+}
+
+func fillDefaultStoragePath(storage map[string]any) {
+ defaultPath, ok := defaultStorageBasePath(getStringValue(storage, "type"))
+ if !ok {
+ return
+ }
+ if strings.TrimSpace(getStringValue(storage, "base_path")) == "" {
+ storage["base_path"] = defaultPath
+ }
+}
+
+func defaultStorageBasePath(storageType string) (string, bool) {
+ switch storageType {
+ case storenum.Local.String():
+ return "./downloads", true
+ case storenum.Webdav.String(), storenum.Alist.String(), storenum.Rclone.String():
+ return "/telegram", true
+ case storenum.Minio.String(), storenum.S3.String():
+ return "telegram", true
+ default:
+ return "", false
+ }
+}
+
+func editableStorageNameSet(storages []map[string]any) map[string]struct{} {
+ names := make(map[string]struct{}, len(storages))
+ for _, storage := range storages {
+ name := getStringValue(storage, "name")
+ if name != "" {
+ names[name] = struct{}{}
+ }
+ }
+ return names
+}
+
+func validateProxyURL(field, value string, required bool) error {
+ value = strings.TrimSpace(value)
+ if value == "" {
+ if required {
+ return fmt.Errorf("%s is required", field)
+ }
+ return nil
+ }
+ u, err := url.Parse(value)
+ if err != nil {
+ return fmt.Errorf("invalid %s: %w", field, err)
+ }
+ switch strings.ToLower(u.Scheme) {
+ case "http", "https", "socks5", "socks5h":
+ return nil
+ default:
+ return fmt.Errorf("%s must use http, https, socks5, or socks5h", field)
+ }
+}
+
+func isRemoteConfigPath(path string) bool {
+ return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://")
+}
+
+func normalizeStorageMap(storage map[string]any) {
+ for _, schema := range StorageSchemas() {
+ if getStringValue(storage, "type") != schema.Type {
+ continue
+ }
+ for _, field := range schema.Fields {
+ switch field.Type {
+ case "int":
+ if value, ok := maybeInt64(storage[field.Name]); ok {
+ storage[field.Name] = value
+ }
+ case "bool":
+ if value, ok := maybeBool(storage[field.Name]); ok {
+ storage[field.Name] = value
+ }
+ case "string_list":
+ if value, ok := maybeStringList(storage[field.Name]); ok {
+ storage[field.Name] = value
+ }
+ default:
+ if value, ok := storage[field.Name]; ok {
+ storage[field.Name] = strings.TrimSpace(fmt.Sprint(value))
+ }
+ }
+ }
+ break
+ }
+ if value, ok := storage["enable"]; ok {
+ if b, ok := maybeBool(value); ok {
+ storage["enable"] = b
+ }
+ }
+}
+
+func getStringValue(values map[string]any, key string) string {
+ value, ok := values[key]
+ if !ok || value == nil {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(value))
+}
+
+func getBoolValue(values map[string]any, key string) bool {
+ value, ok := values[key]
+ if !ok {
+ return false
+ }
+ b, _ := maybeBool(value)
+ return b
+}
+
+func getInt64Value(values map[string]any, key string) int64 {
+ value, ok := values[key]
+ if !ok {
+ return 0
+ }
+ i, _ := maybeInt64(value)
+ return i
+}
+
+func maybeBool(value any) (bool, bool) {
+ switch v := value.(type) {
+ case bool:
+ return v, true
+ case string:
+ b, err := strconv.ParseBool(strings.TrimSpace(v))
+ return b, err == nil
+ default:
+ return false, false
+ }
+}
+
+func maybeInt64(value any) (int64, bool) {
+ switch v := value.(type) {
+ case int:
+ return int64(v), true
+ case int8:
+ return int64(v), true
+ case int16:
+ return int64(v), true
+ case int32:
+ return int64(v), true
+ case int64:
+ return v, true
+ case uint:
+ return int64(v), true
+ case uint8:
+ return int64(v), true
+ case uint16:
+ return int64(v), true
+ case uint32:
+ return int64(v), true
+ case uint64:
+ if v > uint64(^uint64(0)>>1) {
+ return 0, false
+ }
+ return int64(v), true
+ case float32:
+ return int64(v), float32(int64(v)) == v
+ case float64:
+ return int64(v), float64(int64(v)) == v
+ case string:
+ i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
+ return i, err == nil
+ default:
+ return 0, false
+ }
+}
+
+func maybeStringList(value any) ([]string, bool) {
+ switch v := value.(type) {
+ case []string:
+ return v, true
+ case []any:
+ out := make([]string, 0, len(v))
+ for _, item := range v {
+ out = append(out, strings.TrimSpace(fmt.Sprint(item)))
+ }
+ return out, true
+ case string:
+ if strings.TrimSpace(v) == "" {
+ return []string{}, true
+ }
+ parts := regexp.MustCompile(`[\n,]+`).Split(v, -1)
+ out := make([]string, 0, len(parts))
+ for _, part := range parts {
+ part = strings.TrimSpace(part)
+ if part != "" {
+ out = append(out, part)
+ }
+ }
+ return out, true
+ default:
+ return nil, false
+ }
+}
diff --git a/config/storage/alist.go b/config/storage/alist.go
index 9449b245..05ab5513 100644
--- a/config/storage/alist.go
+++ b/config/storage/alist.go
@@ -36,3 +36,7 @@ func (a *AlistStorageConfig) GetType() storenum.StorageType {
func (a *AlistStorageConfig) GetName() string {
return a.Name
}
+
+func (a *AlistStorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&a.BasePath, path)
+}
diff --git a/config/storage/factory.go b/config/storage/factory.go
index 8deeab4b..0785b476 100644
--- a/config/storage/factory.go
+++ b/config/storage/factory.go
@@ -3,6 +3,7 @@ package storage
import (
"fmt"
"reflect"
+ "strings"
storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
"github.com/mitchellh/mapstructure"
@@ -29,10 +30,45 @@ func createStorageConfig(configType StorageConfig) func(cfg *BaseConfig) (Storag
return nil, fmt.Errorf("failed to decode %s storage config: %w", cfg.Type, err)
}
+ fillDefaultBasePath(configValue)
+
return configValue, nil
}
}
+func fillDefaultBasePath(cfg StorageConfig) {
+ basePathSetter, ok := cfg.(interface {
+ SetDefaultBasePathIfEmpty(string)
+ })
+ if !ok {
+ return
+ }
+ defaultPath, ok := defaultBasePath(cfg.GetType())
+ if !ok {
+ return
+ }
+ basePathSetter.SetDefaultBasePathIfEmpty(defaultPath)
+}
+
+func defaultBasePath(storageType storenum.StorageType) (string, bool) {
+ switch storageType {
+ case storenum.Local:
+ return "./downloads", true
+ case storenum.Webdav, storenum.Alist, storenum.Rclone:
+ return "/telegram", true
+ case storenum.Minio, storenum.S3:
+ return "telegram", true
+ default:
+ return "", false
+ }
+}
+
+func setDefaultStringIfEmpty(value *string, fallback string) {
+ if strings.TrimSpace(*value) == "" {
+ *value = fallback
+ }
+}
+
func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) {
var baseConfigs []BaseConfig
if err := v.UnmarshalKey("storages", &baseConfigs); err != nil {
diff --git a/config/storage/factory_test.go b/config/storage/factory_test.go
new file mode 100644
index 00000000..835cc408
--- /dev/null
+++ b/config/storage/factory_test.go
@@ -0,0 +1,64 @@
+package storage
+
+import (
+ "testing"
+
+ storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage"
+)
+
+func TestFillDefaultBasePath(t *testing.T) {
+ tests := []struct {
+ name string
+ cfg StorageConfig
+ expected string
+ }{
+ {name: "local", cfg: &LocalStorageConfig{BaseConfig: BaseConfig{Name: "local", Type: storenum.Local.String()}}, expected: "./downloads"},
+ {name: "webdav", cfg: &WebdavStorageConfig{BaseConfig: BaseConfig{Name: "webdav", Type: storenum.Webdav.String()}}, expected: "/telegram"},
+ {name: "alist", cfg: &AlistStorageConfig{BaseConfig: BaseConfig{Name: "alist", Type: storenum.Alist.String()}}, expected: "/telegram"},
+ {name: "minio", cfg: &MinioStorageConfig{BaseConfig: BaseConfig{Name: "minio", Type: storenum.Minio.String()}}, expected: "telegram"},
+ {name: "s3", cfg: &S3StorageConfig{BaseConfig: BaseConfig{Name: "s3", Type: storenum.S3.String()}}, expected: "telegram"},
+ {name: "rclone", cfg: &RcloneStorageConfig{BaseConfig: BaseConfig{Name: "rclone", Type: storenum.Rclone.String()}}, expected: "/telegram"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ fillDefaultBasePath(tt.cfg)
+ got := basePath(tt.cfg)
+ if got != tt.expected {
+ t.Fatalf("expected base path %q, got %q", tt.expected, got)
+ }
+ })
+ }
+}
+
+func TestFillDefaultBasePathKeepsExistingValue(t *testing.T) {
+ cfg := &LocalStorageConfig{
+ BaseConfig: BaseConfig{Name: "local", Type: storenum.Local.String()},
+ BasePath: "/custom",
+ }
+
+ fillDefaultBasePath(cfg)
+
+ if cfg.BasePath != "/custom" {
+ t.Fatalf("expected existing base path to remain, got %q", cfg.BasePath)
+ }
+}
+
+func basePath(cfg StorageConfig) string {
+ switch c := cfg.(type) {
+ case *LocalStorageConfig:
+ return c.BasePath
+ case *WebdavStorageConfig:
+ return c.BasePath
+ case *AlistStorageConfig:
+ return c.BasePath
+ case *MinioStorageConfig:
+ return c.BasePath
+ case *S3StorageConfig:
+ return c.BasePath
+ case *RcloneStorageConfig:
+ return c.BasePath
+ default:
+ return ""
+ }
+}
diff --git a/config/storage/local.go b/config/storage/local.go
index 46fbf29e..841dcc84 100644
--- a/config/storage/local.go
+++ b/config/storage/local.go
@@ -25,3 +25,7 @@ func (l *LocalStorageConfig) GetType() storenum.StorageType {
func (l *LocalStorageConfig) GetName() string {
return l.Name
}
+
+func (l *LocalStorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&l.BasePath, path)
+}
diff --git a/config/storage/minio.go b/config/storage/minio.go
index 8e9cd20a..d388e2d0 100644
--- a/config/storage/minio.go
+++ b/config/storage/minio.go
@@ -39,3 +39,7 @@ func (m *MinioStorageConfig) GetType() storenum.StorageType {
func (m *MinioStorageConfig) GetName() string {
return m.Name
}
+
+func (m *MinioStorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&m.BasePath, path)
+}
diff --git a/config/storage/rclone.go b/config/storage/rclone.go
index be17193e..1dd5ee7c 100644
--- a/config/storage/rclone.go
+++ b/config/storage/rclone.go
@@ -31,3 +31,7 @@ func (r *RcloneStorageConfig) GetType() storenum.StorageType {
func (r *RcloneStorageConfig) GetName() string {
return r.Name
}
+
+func (r *RcloneStorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&r.BasePath, path)
+}
diff --git a/config/storage/s3.go b/config/storage/s3.go
index 31e7a921..1850dd2d 100644
--- a/config/storage/s3.go
+++ b/config/storage/s3.go
@@ -41,3 +41,7 @@ func (m *S3StorageConfig) GetType() storenum.StorageType {
func (m *S3StorageConfig) GetName() string {
return m.Name
}
+
+func (m *S3StorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&m.BasePath, path)
+}
diff --git a/config/storage/webdav.go b/config/storage/webdav.go
index 93aaac59..2b4a8625 100644
--- a/config/storage/webdav.go
+++ b/config/storage/webdav.go
@@ -34,3 +34,7 @@ func (w *WebdavStorageConfig) GetType() storenum.StorageType {
func (w *WebdavStorageConfig) GetName() string {
return w.Name
}
+
+func (w *WebdavStorageConfig) SetDefaultBasePathIfEmpty(path string) {
+ setDefaultStringIfEmpty(&w.BasePath, path)
+}
diff --git a/config/user.go b/config/user.go
index ac97e0e2..50a36fa7 100644
--- a/config/user.go
+++ b/config/user.go
@@ -1,6 +1,8 @@
package config
import (
+ "strings"
+
"github.com/duke-git/lancet/v2/slice"
)
@@ -33,3 +35,26 @@ func (c Config) HasStorage(userID int64, storageName string) bool {
}
return slice.Contain(us, storageName)
}
+
+func filterKnownStorageNames(names []string, known map[string]struct{}) []string {
+ if len(names) == 0 {
+ return []string{}
+ }
+ filtered := make([]string, 0, len(names))
+ seen := make(map[string]struct{}, len(names))
+ for _, name := range names {
+ name = strings.TrimSpace(name)
+ if name == "" {
+ continue
+ }
+ if _, ok := known[name]; !ok {
+ continue
+ }
+ if _, ok := seen[name]; ok {
+ continue
+ }
+ seen[name] = struct{}{}
+ filtered = append(filtered, name)
+ }
+ return filtered
+}
diff --git a/config/user_test.go b/config/user_test.go
new file mode 100644
index 00000000..8a2a7981
--- /dev/null
+++ b/config/user_test.go
@@ -0,0 +1,16 @@
+package config
+
+import "testing"
+
+func TestFilterKnownStorageNames(t *testing.T) {
+ known := map[string]struct{}{
+ "local": {},
+ "s3": {},
+ }
+
+ got := filterKnownStorageNames([]string{"local", "missing", "s3", "local", ""}, known)
+
+ if len(got) != 2 || got[0] != "local" || got[1] != "s3" {
+ t.Fatalf("expected known storage names to remain once, got %#v", got)
+ }
+}
diff --git a/config/viper.go b/config/viper.go
index 43c7ffc5..9aa74313 100644
--- a/config/viper.go
+++ b/config/viper.go
@@ -129,7 +129,7 @@ func Init(ctx context.Context, configFile ...string) error {
// API
"api.enable": false,
"api.host": "0.0.0.0",
- "api.port": 8080,
+ "api.port": DefaultAPIPort,
"api.token": "",
}
@@ -179,7 +179,7 @@ func Init(ctx context.Context, configFile ...string) error {
if user.Blacklist {
userStorages[user.ID] = slice.Compact(slice.Difference(storages, user.Storages))
} else {
- userStorages[user.ID] = user.Storages
+ userStorages[user.ID] = filterKnownStorageNames(user.Storages, storageNames)
}
}
if cfg.Proxy != "" {
diff --git a/core/tasks/transfer/execute.go b/core/tasks/transfer/execute.go
index dc57da67..e78befd5 100644
--- a/core/tasks/transfer/execute.go
+++ b/core/tasks/transfer/execute.go
@@ -2,6 +2,7 @@ package transfer
import (
"context"
+ "errors"
"fmt"
"io"
"os"
@@ -9,17 +10,23 @@ import (
"path/filepath"
"github.com/charmbracelet/log"
+ "github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/storage"
"golang.org/x/sync/errgroup"
)
+var errTargetPathChanged = errors.New("target path changed")
+
// Execute implements core.Executable.
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("transfer[%s]", t.ID))
logger.Info("Starting transfer task")
- t.Progress.OnStart(ctx, t)
+ t.setPhase("running")
+ if t.Progress != nil {
+ t.Progress.OnStart(ctx, t)
+ }
workers := config.C().Workers
eg, gctx := errgroup.WithContext(ctx)
@@ -65,12 +72,15 @@ func (t *Task) Execute(ctx context.Context) error {
logger.Info("Transfer task completed successfully")
}
- t.Progress.OnDone(ctx, t, err)
+ if t.Progress != nil {
+ t.Progress.OnDone(ctx, t, err)
+ }
return err
}
func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.FileInfo.Name))
+ uploadedSize := elem.FileInfo.Size
// Check whether the source storage supports reading
readableStorage, ok := elem.SourceStorage.(storage.StorageReadable)
@@ -78,26 +88,38 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
return fmt.Errorf("source storage %s does not support reading", elem.SourceStorage.Name())
}
- logger.Info("Opening file from source storage")
- reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath)
- if err != nil {
- return fmt.Errorf("failed to open file: %w", err)
- }
- defer reader.Close()
-
- // Build target storage path: /target_path/filename
- storagePath := path.Join(elem.TargetPath, elem.FileInfo.Name)
-
- // Inject file size into context
- ctx = context.WithValue(ctx, ctxkey.ContentLength, size)
-
if config.C().Stream {
- if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil {
- return fmt.Errorf("failed to upload file to storage: %w", err)
+ for {
+ logger.Info("Opening file from source storage")
+ reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath)
+ if err != nil {
+ return fmt.Errorf("failed to open file: %w", err)
+ }
+ uploadedSize = size
+ err = t.uploadReader(ctx, elem, reader, size)
+ closeErr := reader.Close()
+ if err == nil && closeErr != nil {
+ err = closeErr
+ }
+ if errors.Is(err, errTargetPathChanged) {
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("failed to upload file to storage: %w", err)
+ }
+ break
}
} else {
+ logger.Info("Opening file from source storage")
+ reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath)
+ if err != nil {
+ return fmt.Errorf("failed to open file: %w", err)
+ }
+ uploadedSize = size
+ defer reader.Close()
+
logger.Info("Downloading to temporary file for ReadSeeker support")
- tempFile, err := t.downloadToTemp(reader, elem.FileInfo.Name)
+ tempFile, err := t.downloadToTemp(ctx, reader, elem.FileInfo.Name)
if err != nil {
return fmt.Errorf("failed to download to temp: %w", err)
}
@@ -109,19 +131,30 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
}
logger.Infof("Uploading file to storage (size: %d bytes)", size)
- if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil {
- return fmt.Errorf("failed to upload file to storage: %w", err)
+ for {
+ if _, err := tempFile.Seek(0, io.SeekStart); err != nil {
+ return fmt.Errorf("failed to seek temp file: %w", err)
+ }
+ if err := t.uploadReader(ctx, elem, tempFile, size); err != nil {
+ if errors.Is(err, errTargetPathChanged) {
+ continue
+ }
+ return fmt.Errorf("failed to upload file to storage: %w", err)
+ }
+ break
}
}
- t.uploaded.Add(size)
- t.Progress.OnProgress(ctx, t)
+ t.uploaded.Add(uploadedSize)
+ if t.Progress != nil {
+ t.Progress.OnProgress(ctx, t)
+ }
logger.Info("File uploaded successfully")
return nil
}
-func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, error) {
+func (t *Task) downloadToTemp(ctx context.Context, reader io.Reader, filename string) (*os.File, error) {
tempDir := config.C().Temp.BasePath
if tempDir == "" {
tempDir = os.TempDir()
@@ -132,7 +165,18 @@ func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, erro
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
- if _, err := io.Copy(tempFile, reader); err != nil {
+ t.setPhase("downloading")
+ if t.Progress != nil {
+ t.Progress.OnProgress(ctx, t)
+ }
+ wr := ioutil.NewProgressWriter(tempFile, func(n int) {
+ t.downloaded.Add(int64(n))
+ if t.Progress != nil {
+ t.Progress.OnProgress(ctx, t)
+ }
+ })
+
+ if _, err := io.Copy(wr, reader); err != nil {
tempFile.Close()
os.Remove(tempFile.Name())
return nil, fmt.Errorf("failed to copy to temp file: %w", err)
@@ -140,3 +184,51 @@ func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, erro
return tempFile, nil
}
+
+func (t *Task) uploadReader(ctx context.Context, elem TaskElement, reader io.Reader, size int64) error {
+ version := t.pathVersion.Load()
+ uploadCtx, cancel := context.WithCancel(ctx)
+ t.registerUploadCancel(elem.ID, cancel)
+ defer func() {
+ cancel()
+ t.clearUploadCancel(elem.ID)
+ }()
+
+ t.setPhase("uploading")
+ if t.Progress != nil {
+ t.Progress.OnProgress(uploadCtx, t)
+ }
+ storagePath := path.Join(t.currentTargetPath(), elem.FileInfo.Name)
+ uploadCtx = context.WithValue(uploadCtx, ctxkey.ContentLength, size)
+ progressReader := ioutil.NewProgressReader(asReadSeeker(reader), size, func(read int64, total int64) {
+ t.setUploadingBytes(elem.ID, read)
+ if t.Progress != nil {
+ t.Progress.OnProgress(uploadCtx, t)
+ }
+ })
+ if err := elem.TargetStorage.Save(uploadCtx, progressReader, storagePath); err != nil {
+ if uploadCtx.Err() != nil && ctx.Err() == nil && version != t.pathVersion.Load() {
+ return errTargetPathChanged
+ }
+ return err
+ }
+ return nil
+}
+
+func asReadSeeker(reader io.Reader) io.ReadSeeker {
+ if rs, ok := reader.(io.ReadSeeker); ok {
+ return rs
+ }
+ return readSeekerAdapter{Reader: reader}
+}
+
+type readSeekerAdapter struct {
+ io.Reader
+}
+
+func (r readSeekerAdapter) Seek(offset int64, whence int) (int64, error) {
+ if offset == 0 && whence == io.SeekStart {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("reader is not seekable")
+}
diff --git a/core/tasks/transfer/task.go b/core/tasks/transfer/task.go
index 85e2bdef..aefd6725 100644
--- a/core/tasks/transfer/task.go
+++ b/core/tasks/transfer/task.go
@@ -30,11 +30,20 @@ type Task struct {
elems []TaskElement
Progress ProgressTracker
IgnoreErrors bool
+ downloaded atomic.Int64
uploaded atomic.Int64
totalSize int64
processing map[string]TaskElementInfo
processingMu sync.RWMutex
failed map[string]error
+ targetPathMu sync.RWMutex
+ targetPath string
+ pathVersion atomic.Int64
+ phaseMu sync.RWMutex
+ phase string
+ uploadMu sync.RWMutex
+ uploading map[string]int64
+ uploadCancel map[string]context.CancelFunc
}
// Title implements core.Executable.
@@ -81,7 +90,14 @@ func NewTransferTask(
ctx: ctx,
elems: elems,
Progress: progress,
+ phase: "queued",
uploaded: atomic.Int64{},
+ targetPath: func() string {
+ if len(elems) == 0 {
+ return ""
+ }
+ return elems[0].TargetPath
+ }(),
totalSize: func() int64 {
var total int64
for _, elem := range elems {
@@ -92,6 +108,64 @@ func NewTransferTask(
processing: make(map[string]TaskElementInfo),
IgnoreErrors: ignoreErrors,
failed: make(map[string]error),
+ uploading: make(map[string]int64),
+ uploadCancel: make(map[string]context.CancelFunc),
}
return task
}
+
+func (t *Task) UpdateTargetPath(targetPath string) {
+ t.targetPathMu.Lock()
+ t.targetPath = targetPath
+ t.targetPathMu.Unlock()
+ t.pathVersion.Add(1)
+ t.uploaded.Store(0)
+
+ t.uploadMu.Lock()
+ t.uploading = make(map[string]int64)
+ for _, cancel := range t.uploadCancel {
+ cancel()
+ }
+ t.uploadMu.Unlock()
+}
+
+func (t *Task) currentTargetPath() string {
+ t.targetPathMu.RLock()
+ defer t.targetPathMu.RUnlock()
+ return t.targetPath
+}
+
+func (t *Task) setPhase(phase string) {
+ t.phaseMu.Lock()
+ t.phase = phase
+ t.phaseMu.Unlock()
+}
+
+func (t *Task) registerUploadCancel(id string, cancel context.CancelFunc) {
+ t.uploadMu.Lock()
+ defer t.uploadMu.Unlock()
+ t.uploadCancel[id] = cancel
+}
+
+func (t *Task) clearUploadCancel(id string) {
+ t.uploadMu.Lock()
+ defer t.uploadMu.Unlock()
+ delete(t.uploadCancel, id)
+ delete(t.uploading, id)
+}
+
+func (t *Task) setUploadingBytes(id string, n int64) {
+ t.uploadMu.Lock()
+ defer t.uploadMu.Unlock()
+ t.uploading[id] = n
+}
+
+func (t *Task) uploadedInFlight() int64 {
+ t.uploadMu.RLock()
+ defer t.uploadMu.RUnlock()
+ var total int64
+ for _, n := range t.uploading {
+ total += n
+ }
+ return total
+}
diff --git a/core/tasks/transfer/taskinfo.go b/core/tasks/transfer/taskinfo.go
index f9ec571e..e4d4e80a 100644
--- a/core/tasks/transfer/taskinfo.go
+++ b/core/tasks/transfer/taskinfo.go
@@ -26,18 +26,28 @@ func (e *TaskElement) SourceStorageName() string {
type TaskInfo interface {
TaskID() string
TotalSize() int64
+ Downloaded() int64
Uploaded() int64
Count() int
Processing() []TaskElementInfo
FailedFiles() []string
+ Phase() string
+ SourceStorageName() string
+ SourcePath() string
+ TargetStorageName() string
+ TargetPath() string
}
func (t *Task) TotalSize() int64 {
return t.totalSize
}
+func (t *Task) Downloaded() int64 {
+ return t.downloaded.Load()
+}
+
func (t *Task) Uploaded() int64 {
- return t.uploaded.Load()
+ return t.uploaded.Load() + t.uploadedInFlight()
}
func (t *Task) Count() int {
@@ -71,3 +81,34 @@ func (t *Task) FailedFiles() []string {
}
return result
}
+
+func (t *Task) Phase() string {
+ t.phaseMu.RLock()
+ defer t.phaseMu.RUnlock()
+ return t.phase
+}
+
+func (t *Task) SourceStorageName() string {
+ if len(t.elems) == 0 {
+ return ""
+ }
+ return t.elems[0].SourceStorage.Name()
+}
+
+func (t *Task) SourcePath() string {
+ if len(t.elems) == 0 {
+ return ""
+ }
+ return t.elems[0].SourcePath
+}
+
+func (t *Task) TargetStorageName() string {
+ if len(t.elems) == 0 {
+ return ""
+ }
+ return t.elems[0].TargetStorage.Name()
+}
+
+func (t *Task) TargetPath() string {
+ return t.currentTargetPath()
+}
diff --git a/database/db.go b/database/db.go
index 3c4300e3..b4b50385 100644
--- a/database/db.go
+++ b/database/db.go
@@ -9,6 +9,7 @@ import (
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
+ "github.com/krau/SaveAny-Bot/pkg/rule"
"gorm.io/gorm"
glogger "gorm.io/gorm/logger"
)
@@ -16,9 +17,15 @@ import (
var db *gorm.DB
func Init(ctx context.Context) {
+ if err := Open(ctx); err != nil {
+ log.FromContext(ctx).Fatal("Database initialization failed", "error", err)
+ }
+}
+
+func Open(ctx context.Context) error {
logger := log.FromContext(ctx)
if err := os.MkdirAll(filepath.Dir(config.C().DB.Path), 0755); err != nil {
- logger.Fatal("Failed to create data directory: ", err)
+ return fmt.Errorf("failed to create data directory: %w", err)
}
var err error
db, err = gorm.Open(GetDialect(config.C().DB.Path), &gorm.Config{
@@ -32,20 +39,28 @@ func Init(ctx context.Context) {
PrepareStmt: true,
})
if err != nil {
- logger.Fatal("Failed to open database: ", err)
+ return fmt.Errorf("failed to open database: %w", err)
}
logger.Debug("Database connected")
if err := db.AutoMigrate(&User{}, &Dir{}, &Rule{}, &WatchChat{}); err != nil {
- logger.Fatal("Database migration failed; if upgrading from an old version, try deleting the database file and retrying", "error", err)
+ return fmt.Errorf("database migration failed; if upgrading from an old version, try deleting the database file and retrying: %w", err)
}
- if err := syncUsers(ctx); err != nil {
- logger.Fatal("Failed to sync users:", err)
+ if err := SyncUsers(ctx); err != nil {
+ return fmt.Errorf("failed to sync users: %w", err)
}
logger.Debug("Database migrated")
logger.Info("Database initialized")
+ return nil
}
-func syncUsers(ctx context.Context) error {
+func Ready() bool {
+ return db != nil
+}
+
+func SyncUsers(ctx context.Context) error {
+ if db == nil {
+ return fmt.Errorf("database is not initialized")
+ }
logger := log.FromContext(ctx)
dbUsers, err := GetAllUsers(ctx)
if err != nil {
@@ -80,5 +95,72 @@ func syncUsers(ctx context.Context) error {
}
}
+ if err := cleanupUnavailableStorageReferences(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func cleanupUnavailableStorageReferences(ctx context.Context) error {
+ knownStorages := make(map[string]struct{}, len(config.C().Storages))
+ for _, storage := range config.C().Storages {
+ knownStorages[storage.GetName()] = struct{}{}
+ }
+
+ names := make(map[string]struct{})
+ for _, user := range config.C().Users {
+ for _, storageName := range user.Storages {
+ if _, ok := knownStorages[storageName]; !ok && storageName != "" {
+ names[storageName] = struct{}{}
+ }
+ }
+ }
+
+ var users []User
+ if err := db.WithContext(ctx).
+ Where("default_storage <> ''").
+ Find(&users).Error; err != nil {
+ return fmt.Errorf("failed to find users with default storage: %w", err)
+ }
+ for _, user := range users {
+ if _, ok := knownStorages[user.DefaultStorage]; !ok {
+ names[user.DefaultStorage] = struct{}{}
+ }
+ }
+
+ var dirs []Dir
+ if err := db.WithContext(ctx).
+ Select("storage_name").
+ Group("storage_name").
+ Find(&dirs).Error; err != nil {
+ return fmt.Errorf("failed to find storage dirs: %w", err)
+ }
+ for _, dir := range dirs {
+ if _, ok := knownStorages[dir.StorageName]; !ok && dir.StorageName != "" {
+ names[dir.StorageName] = struct{}{}
+ }
+ }
+
+ var rules []Rule
+ if err := db.WithContext(ctx).
+ Select("storage_name").
+ Where("storage_name <> '' AND storage_name <> ?", rule.RuleStorNameChosen).
+ Group("storage_name").
+ Find(&rules).Error; err != nil {
+ return fmt.Errorf("failed to find storage rules: %w", err)
+ }
+ for _, rule := range rules {
+ if _, ok := knownStorages[rule.StorageName]; !ok {
+ names[rule.StorageName] = struct{}{}
+ }
+ }
+
+ for name := range names {
+ if err := ClearStorageReferences(ctx, name); err != nil {
+ return fmt.Errorf("failed to clear references for storage %s: %w", name, err)
+ }
+ log.FromContext(ctx).Warnf("Cleared references to unavailable storage: %s", name)
+ }
return nil
}
diff --git a/database/storage_cleanup_test.go b/database/storage_cleanup_test.go
new file mode 100644
index 00000000..438098f3
--- /dev/null
+++ b/database/storage_cleanup_test.go
@@ -0,0 +1,150 @@
+package database
+
+import (
+ "context"
+ "testing"
+
+ "github.com/krau/SaveAny-Bot/pkg/rule"
+ "gorm.io/gorm"
+)
+
+func TestClearStorageReferences(t *testing.T) {
+ ctx := context.Background()
+ openTestDB(t)
+
+ user := User{
+ ChatID: 1001,
+ DefaultStorage: "removed",
+ Silent: true,
+ }
+ if err := db.WithContext(ctx).Create(&user).Error; err != nil {
+ t.Fatalf("failed to create user: %v", err)
+ }
+ dir := Dir{UserID: user.ID, StorageName: "removed", Path: "/old"}
+ otherDir := Dir{UserID: user.ID, StorageName: "kept", Path: "/new"}
+ if err := db.WithContext(ctx).Create(&dir).Error; err != nil {
+ t.Fatalf("failed to create removed dir: %v", err)
+ }
+ if err := db.WithContext(ctx).Create(&otherDir).Error; err != nil {
+ t.Fatalf("failed to create kept dir: %v", err)
+ }
+ user.DefaultDir = dir.ID
+ if err := db.WithContext(ctx).Save(&user).Error; err != nil {
+ t.Fatalf("failed to update user default dir: %v", err)
+ }
+ if err := db.WithContext(ctx).Create(&Rule{
+ UserID: user.ID,
+ Type: rule.FileNameRegex.String(),
+ Data: ".*",
+ StorageName: "removed",
+ DirPath: "/old",
+ }).Error; err != nil {
+ t.Fatalf("failed to create rule: %v", err)
+ }
+
+ if err := ClearStorageReferences(ctx, "removed"); err != nil {
+ t.Fatalf("cleanup failed: %v", err)
+ }
+
+ var gotUser User
+ if err := db.WithContext(ctx).First(&gotUser, user.ID).Error; err != nil {
+ t.Fatalf("failed to load user: %v", err)
+ }
+ if gotUser.DefaultStorage != "" || gotUser.DefaultDir != 0 || gotUser.Silent {
+ t.Fatalf("expected user defaults and silent to be cleared, got storage=%q dir=%d silent=%v", gotUser.DefaultStorage, gotUser.DefaultDir, gotUser.Silent)
+ }
+
+ var dirCount int64
+ if err := db.WithContext(ctx).Model(&Dir{}).Where("storage_name = ?", "removed").Count(&dirCount).Error; err != nil {
+ t.Fatalf("failed to count removed dirs: %v", err)
+ }
+ if dirCount != 0 {
+ t.Fatalf("expected removed storage dirs to be deleted, got %d", dirCount)
+ }
+ if err := db.WithContext(ctx).Model(&Dir{}).Where("storage_name = ?", "kept").Count(&dirCount).Error; err != nil {
+ t.Fatalf("failed to count kept dirs: %v", err)
+ }
+ if dirCount != 1 {
+ t.Fatalf("expected kept storage dir to remain, got %d", dirCount)
+ }
+
+ var gotRule Rule
+ if err := db.WithContext(ctx).Where("user_id = ?", user.ID).First(&gotRule).Error; err != nil {
+ t.Fatalf("failed to load rule: %v", err)
+ }
+ if gotRule.StorageName != rule.RuleStorNameChosen {
+ t.Fatalf("expected rule storage to be reset to %q, got %q", rule.RuleStorNameChosen, gotRule.StorageName)
+ }
+}
+
+func TestRenameStorageReferences(t *testing.T) {
+ ctx := context.Background()
+ openTestDB(t)
+
+ user := User{
+ ChatID: 1001,
+ DefaultStorage: "old",
+ Silent: true,
+ }
+ if err := db.WithContext(ctx).Create(&user).Error; err != nil {
+ t.Fatalf("failed to create user: %v", err)
+ }
+ dir := Dir{UserID: user.ID, StorageName: "old", Path: "/old"}
+ if err := db.WithContext(ctx).Create(&dir).Error; err != nil {
+ t.Fatalf("failed to create dir: %v", err)
+ }
+ user.DefaultDir = dir.ID
+ if err := db.WithContext(ctx).Save(&user).Error; err != nil {
+ t.Fatalf("failed to update user default dir: %v", err)
+ }
+ if err := db.WithContext(ctx).Create(&Rule{
+ UserID: user.ID,
+ Type: rule.FileNameRegex.String(),
+ Data: ".*",
+ StorageName: "old",
+ DirPath: "/old",
+ }).Error; err != nil {
+ t.Fatalf("failed to create rule: %v", err)
+ }
+
+ if err := RenameStorageReferences(ctx, "old", "new"); err != nil {
+ t.Fatalf("rename failed: %v", err)
+ }
+
+ var gotUser User
+ if err := db.WithContext(ctx).First(&gotUser, user.ID).Error; err != nil {
+ t.Fatalf("failed to load user: %v", err)
+ }
+ if gotUser.DefaultStorage != "new" || gotUser.DefaultDir != dir.ID || !gotUser.Silent {
+ t.Fatalf("expected user defaults to move to new storage, got storage=%q dir=%d silent=%v", gotUser.DefaultStorage, gotUser.DefaultDir, gotUser.Silent)
+ }
+
+ var gotDir Dir
+ if err := db.WithContext(ctx).First(&gotDir, dir.ID).Error; err != nil {
+ t.Fatalf("failed to load dir: %v", err)
+ }
+ if gotDir.StorageName != "new" {
+ t.Fatalf("expected dir storage to be renamed, got %q", gotDir.StorageName)
+ }
+
+ var gotRule Rule
+ if err := db.WithContext(ctx).Where("user_id = ?", user.ID).First(&gotRule).Error; err != nil {
+ t.Fatalf("failed to load rule: %v", err)
+ }
+ if gotRule.StorageName != "new" {
+ t.Fatalf("expected rule storage to be renamed, got %q", gotRule.StorageName)
+ }
+}
+
+func openTestDB(t *testing.T) {
+ t.Helper()
+ var err error
+ db, err = gorm.Open(GetDialect(t.TempDir()+"/saveany.db"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open test database: %v", err)
+ }
+ if err := db.AutoMigrate(&User{}, &Dir{}, &Rule{}, &WatchChat{}); err != nil {
+ t.Fatalf("failed to migrate test database: %v", err)
+ }
+ t.Cleanup(func() { db = nil })
+}
diff --git a/database/user.go b/database/user.go
index 7f631b5e..75ee70fe 100644
--- a/database/user.go
+++ b/database/user.go
@@ -2,7 +2,10 @@ package database
import (
"context"
+ "fmt"
+ "github.com/krau/SaveAny-Bot/pkg/rule"
+ "gorm.io/gorm"
"gorm.io/gorm/clause"
)
@@ -50,3 +53,66 @@ func GetUserByID(ctx context.Context, id uint) (*User, error) {
Where("id = ?", id).First(&user).Error
return &user, err
}
+
+func ClearStorageReferences(ctx context.Context, storageName string) error {
+ if db == nil {
+ return fmt.Errorf("database is not initialized")
+ }
+ if storageName == "" {
+ return nil
+ }
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ if err := tx.Model(&User{}).
+ Where("default_storage = ?", storageName).
+ Updates(map[string]any{
+ "default_storage": "",
+ "default_dir": 0,
+ "silent": false,
+ }).Error; err != nil {
+ return err
+ }
+ if err := tx.Model(&User{}).
+ Where("default_dir IN (?)", tx.Model(&Dir{}).Select("id").Where("storage_name = ?", storageName)).
+ Update("default_dir", 0).Error; err != nil {
+ return err
+ }
+ if err := tx.Unscoped().
+ Where("storage_name = ?", storageName).
+ Delete(&Dir{}).Error; err != nil {
+ return err
+ }
+ if err := tx.Model(&Rule{}).
+ Where("storage_name = ?", storageName).
+ Update("storage_name", rule.RuleStorNameChosen).Error; err != nil {
+ return err
+ }
+ return nil
+ })
+}
+
+func RenameStorageReferences(ctx context.Context, oldName, newName string) error {
+ if db == nil {
+ return fmt.Errorf("database is not initialized")
+ }
+ if oldName == "" || newName == "" || oldName == newName {
+ return nil
+ }
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ if err := tx.Model(&User{}).
+ Where("default_storage = ?", oldName).
+ Update("default_storage", newName).Error; err != nil {
+ return err
+ }
+ if err := tx.Model(&Dir{}).
+ Where("storage_name = ?", oldName).
+ Update("storage_name", newName).Error; err != nil {
+ return err
+ }
+ if err := tx.Model(&Rule{}).
+ Where("storage_name = ?", oldName).
+ Update("storage_name", newName).Error; err != nil {
+ return err
+ }
+ return nil
+ })
+}
diff --git a/storage/load.go b/storage/load.go
index 09bf4bf4..e5005ff7 100644
--- a/storage/load.go
+++ b/storage/load.go
@@ -2,7 +2,9 @@ package storage
import (
"context"
+ "errors"
"fmt"
+ "sync"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
@@ -10,6 +12,7 @@ import (
)
var UserStorages = make(map[int64][]Storage)
+var storagesMu sync.RWMutex
// GetStorageByName returns storage by name from cache or creates new one
// It should NOT be used to get storage for user, use GetStorageByUserIDAndName instead
@@ -18,10 +21,13 @@ func GetStorageByName(ctx context.Context, name string) (Storage, error) {
return nil, ErrStorageNameEmpty
}
+ storagesMu.RLock()
storage, ok := Storages[name]
+ storagesMu.RUnlock()
if ok {
return storage, nil
}
+
cfg := config.C().GetStorageByName(name)
if cfg == nil {
return nil, fmt.Errorf("未找到存储 %s", name)
@@ -31,7 +37,14 @@ func GetStorageByName(ctx context.Context, name string) (Storage, error) {
if err != nil {
return nil, err
}
+
+ storagesMu.Lock()
+ if existing, ok := Storages[name]; ok {
+ storagesMu.Unlock()
+ return existing, nil
+ }
Storages[name] = storage
+ storagesMu.Unlock()
return storage, nil
}
@@ -52,9 +65,14 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage {
if chatID <= 0 {
return nil
}
+ storagesMu.RLock()
if storages, ok := UserStorages[chatID]; ok {
- return storages
+ out := append([]Storage(nil), storages...)
+ storagesMu.RUnlock()
+ return out
}
+ storagesMu.RUnlock()
+
var storages []Storage
for _, name := range config.C().GetStorageNamesByUserID(chatID) {
storage, err := GetStorageByName(ctx, name)
@@ -63,22 +81,64 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage {
}
storages = append(storages, storage)
}
- return storages
+ storagesMu.Lock()
+ UserStorages[chatID] = storages
+ out := append([]Storage(nil), storages...)
+ storagesMu.Unlock()
+ return out
}
func LoadStorages(ctx context.Context) {
+ logger := log.FromContext(ctx)
+ if err := ReloadStorages(ctx); err != nil {
+ logger.Errorf("failed to load some storages: %v", err)
+ }
+}
+
+func ReloadStorages(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Debug("loading storages...")
- for _, storage := range config.C().Storages {
- _, err := GetStorageByName(ctx, storage.GetName())
+
+ nextStorages := make(map[string]Storage)
+ var errs []error
+ for _, cfg := range config.C().Storages {
+ storage, err := NewStorage(ctx, cfg)
if err != nil {
- logger.Errorf("failed to load storage %s: %v", storage.GetName(), err)
+ errs = append(errs, fmt.Errorf("%s: %w", cfg.GetName(), err))
+ logger.Errorf("failed to load storage %s: %v", cfg.GetName(), err)
+ continue
}
+ nextStorages[cfg.GetName()] = storage
}
- logger.Infof("successfully loaded %d storages", len(Storages))
- for user := range config.C().GetUsersID() {
- UserStorages[int64(user)] = GetUserStorages(ctx, int64(user))
+
+ nextUserStorages := make(map[int64][]Storage)
+ for _, userID := range config.C().GetUsersID() {
+ for _, name := range config.C().GetStorageNamesByUserID(userID) {
+ storage, ok := nextStorages[name]
+ if ok {
+ nextUserStorages[userID] = append(nextUserStorages[userID], storage)
+ }
+ }
+ }
+
+ storagesMu.Lock()
+ Storages = nextStorages
+ UserStorages = nextUserStorages
+ storagesMu.Unlock()
+
+ logger.Infof("successfully loaded %d storages", len(nextStorages))
+ return errors.Join(errs...)
+}
+
+func Snapshot() map[string]Storage {
+ storagesMu.RLock()
+ defer storagesMu.RUnlock()
+
+ out := make(map[string]Storage, len(Storages))
+ for name, storage := range Storages {
+ out[name] = storage
}
+ return out
}
// GetTelegramStorageByUserID returns the first enabled Telegram storage for the user