From 8dd6f954c1f22f34779bca4a3827f6e7465f01be Mon Sep 17 00:00:00 2001 From: liyw0205 Date: Mon, 25 May 2026 11:51:28 +0800 Subject: [PATCH 1/3] fix(config): clean storage references --- config/storage/alist.go | 4 ++ config/storage/factory.go | 36 +++++++++++++ config/storage/factory_test.go | 64 ++++++++++++++++++++++ config/storage/local.go | 4 ++ config/storage/minio.go | 4 ++ config/storage/rclone.go | 4 ++ config/storage/s3.go | 4 ++ config/storage/webdav.go | 4 ++ config/user.go | 25 +++++++++ config/user_test.go | 16 ++++++ config/viper.go | 2 +- database/db.go | 68 ++++++++++++++++++++++++ database/storage_cleanup_test.go | 91 ++++++++++++++++++++++++++++++++ database/user.go | 39 ++++++++++++++ 14 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 config/storage/factory_test.go create mode 100644 config/user_test.go create mode 100644 database/storage_cleanup_test.go diff --git a/config/storage/alist.go b/config/storage/alist.go index 9449b245..05ab5513 100644 --- a/config/storage/alist.go +++ b/config/storage/alist.go @@ -36,3 +36,7 @@ func (a *AlistStorageConfig) GetType() storenum.StorageType { func (a *AlistStorageConfig) GetName() string { return a.Name } + +func (a *AlistStorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&a.BasePath, path) +} diff --git a/config/storage/factory.go b/config/storage/factory.go index 8deeab4b..0785b476 100644 --- a/config/storage/factory.go +++ b/config/storage/factory.go @@ -3,6 +3,7 @@ package storage import ( "fmt" "reflect" + "strings" storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" "github.com/mitchellh/mapstructure" @@ -29,10 +30,45 @@ func createStorageConfig(configType StorageConfig) func(cfg *BaseConfig) (Storag return nil, fmt.Errorf("failed to decode %s storage config: %w", cfg.Type, err) } + fillDefaultBasePath(configValue) + return configValue, nil } } +func fillDefaultBasePath(cfg StorageConfig) { + basePathSetter, ok := cfg.(interface { + SetDefaultBasePathIfEmpty(string) + }) + if !ok { + return + } + defaultPath, ok := defaultBasePath(cfg.GetType()) + if !ok { + return + } + basePathSetter.SetDefaultBasePathIfEmpty(defaultPath) +} + +func defaultBasePath(storageType storenum.StorageType) (string, bool) { + switch storageType { + case storenum.Local: + return "./downloads", true + case storenum.Webdav, storenum.Alist, storenum.Rclone: + return "/telegram", true + case storenum.Minio, storenum.S3: + return "telegram", true + default: + return "", false + } +} + +func setDefaultStringIfEmpty(value *string, fallback string) { + if strings.TrimSpace(*value) == "" { + *value = fallback + } +} + func LoadStorageConfigs(v *viper.Viper) ([]StorageConfig, error) { var baseConfigs []BaseConfig if err := v.UnmarshalKey("storages", &baseConfigs); err != nil { diff --git a/config/storage/factory_test.go b/config/storage/factory_test.go new file mode 100644 index 00000000..835cc408 --- /dev/null +++ b/config/storage/factory_test.go @@ -0,0 +1,64 @@ +package storage + +import ( + "testing" + + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" +) + +func TestFillDefaultBasePath(t *testing.T) { + tests := []struct { + name string + cfg StorageConfig + expected string + }{ + {name: "local", cfg: &LocalStorageConfig{BaseConfig: BaseConfig{Name: "local", Type: storenum.Local.String()}}, expected: "./downloads"}, + {name: "webdav", cfg: &WebdavStorageConfig{BaseConfig: BaseConfig{Name: "webdav", Type: storenum.Webdav.String()}}, expected: "/telegram"}, + {name: "alist", cfg: &AlistStorageConfig{BaseConfig: BaseConfig{Name: "alist", Type: storenum.Alist.String()}}, expected: "/telegram"}, + {name: "minio", cfg: &MinioStorageConfig{BaseConfig: BaseConfig{Name: "minio", Type: storenum.Minio.String()}}, expected: "telegram"}, + {name: "s3", cfg: &S3StorageConfig{BaseConfig: BaseConfig{Name: "s3", Type: storenum.S3.String()}}, expected: "telegram"}, + {name: "rclone", cfg: &RcloneStorageConfig{BaseConfig: BaseConfig{Name: "rclone", Type: storenum.Rclone.String()}}, expected: "/telegram"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fillDefaultBasePath(tt.cfg) + got := basePath(tt.cfg) + if got != tt.expected { + t.Fatalf("expected base path %q, got %q", tt.expected, got) + } + }) + } +} + +func TestFillDefaultBasePathKeepsExistingValue(t *testing.T) { + cfg := &LocalStorageConfig{ + BaseConfig: BaseConfig{Name: "local", Type: storenum.Local.String()}, + BasePath: "/custom", + } + + fillDefaultBasePath(cfg) + + if cfg.BasePath != "/custom" { + t.Fatalf("expected existing base path to remain, got %q", cfg.BasePath) + } +} + +func basePath(cfg StorageConfig) string { + switch c := cfg.(type) { + case *LocalStorageConfig: + return c.BasePath + case *WebdavStorageConfig: + return c.BasePath + case *AlistStorageConfig: + return c.BasePath + case *MinioStorageConfig: + return c.BasePath + case *S3StorageConfig: + return c.BasePath + case *RcloneStorageConfig: + return c.BasePath + default: + return "" + } +} diff --git a/config/storage/local.go b/config/storage/local.go index 46fbf29e..841dcc84 100644 --- a/config/storage/local.go +++ b/config/storage/local.go @@ -25,3 +25,7 @@ func (l *LocalStorageConfig) GetType() storenum.StorageType { func (l *LocalStorageConfig) GetName() string { return l.Name } + +func (l *LocalStorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&l.BasePath, path) +} diff --git a/config/storage/minio.go b/config/storage/minio.go index 8e9cd20a..d388e2d0 100644 --- a/config/storage/minio.go +++ b/config/storage/minio.go @@ -39,3 +39,7 @@ func (m *MinioStorageConfig) GetType() storenum.StorageType { func (m *MinioStorageConfig) GetName() string { return m.Name } + +func (m *MinioStorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&m.BasePath, path) +} diff --git a/config/storage/rclone.go b/config/storage/rclone.go index be17193e..1dd5ee7c 100644 --- a/config/storage/rclone.go +++ b/config/storage/rclone.go @@ -31,3 +31,7 @@ func (r *RcloneStorageConfig) GetType() storenum.StorageType { func (r *RcloneStorageConfig) GetName() string { return r.Name } + +func (r *RcloneStorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&r.BasePath, path) +} diff --git a/config/storage/s3.go b/config/storage/s3.go index 31e7a921..1850dd2d 100644 --- a/config/storage/s3.go +++ b/config/storage/s3.go @@ -41,3 +41,7 @@ func (m *S3StorageConfig) GetType() storenum.StorageType { func (m *S3StorageConfig) GetName() string { return m.Name } + +func (m *S3StorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&m.BasePath, path) +} diff --git a/config/storage/webdav.go b/config/storage/webdav.go index 93aaac59..2b4a8625 100644 --- a/config/storage/webdav.go +++ b/config/storage/webdav.go @@ -34,3 +34,7 @@ func (w *WebdavStorageConfig) GetType() storenum.StorageType { func (w *WebdavStorageConfig) GetName() string { return w.Name } + +func (w *WebdavStorageConfig) SetDefaultBasePathIfEmpty(path string) { + setDefaultStringIfEmpty(&w.BasePath, path) +} diff --git a/config/user.go b/config/user.go index ac97e0e2..50a36fa7 100644 --- a/config/user.go +++ b/config/user.go @@ -1,6 +1,8 @@ package config import ( + "strings" + "github.com/duke-git/lancet/v2/slice" ) @@ -33,3 +35,26 @@ func (c Config) HasStorage(userID int64, storageName string) bool { } return slice.Contain(us, storageName) } + +func filterKnownStorageNames(names []string, known map[string]struct{}) []string { + if len(names) == 0 { + return []string{} + } + filtered := make([]string, 0, len(names)) + seen := make(map[string]struct{}, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if _, ok := known[name]; !ok { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + filtered = append(filtered, name) + } + return filtered +} diff --git a/config/user_test.go b/config/user_test.go new file mode 100644 index 00000000..8a2a7981 --- /dev/null +++ b/config/user_test.go @@ -0,0 +1,16 @@ +package config + +import "testing" + +func TestFilterKnownStorageNames(t *testing.T) { + known := map[string]struct{}{ + "local": {}, + "s3": {}, + } + + got := filterKnownStorageNames([]string{"local", "missing", "s3", "local", ""}, known) + + if len(got) != 2 || got[0] != "local" || got[1] != "s3" { + t.Fatalf("expected known storage names to remain once, got %#v", got) + } +} diff --git a/config/viper.go b/config/viper.go index 43c7ffc5..af0de879 100644 --- a/config/viper.go +++ b/config/viper.go @@ -179,7 +179,7 @@ func Init(ctx context.Context, configFile ...string) error { if user.Blacklist { userStorages[user.ID] = slice.Compact(slice.Difference(storages, user.Storages)) } else { - userStorages[user.ID] = user.Storages + userStorages[user.ID] = filterKnownStorageNames(user.Storages, storageNames) } } if cfg.Proxy != "" { diff --git a/database/db.go b/database/db.go index 3c4300e3..71870eab 100644 --- a/database/db.go +++ b/database/db.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/rule" "gorm.io/gorm" glogger "gorm.io/gorm/logger" ) @@ -80,5 +81,72 @@ func syncUsers(ctx context.Context) error { } } + if err := cleanupUnavailableStorageReferences(ctx); err != nil { + return err + } + + return nil +} + +func cleanupUnavailableStorageReferences(ctx context.Context) error { + knownStorages := make(map[string]struct{}, len(config.C().Storages)) + for _, storage := range config.C().Storages { + knownStorages[storage.GetName()] = struct{}{} + } + + names := make(map[string]struct{}) + for _, user := range config.C().Users { + for _, storageName := range user.Storages { + if _, ok := knownStorages[storageName]; !ok && storageName != "" { + names[storageName] = struct{}{} + } + } + } + + var users []User + if err := db.WithContext(ctx). + Where("default_storage <> ''"). + Find(&users).Error; err != nil { + return fmt.Errorf("failed to find users with default storage: %w", err) + } + for _, user := range users { + if _, ok := knownStorages[user.DefaultStorage]; !ok { + names[user.DefaultStorage] = struct{}{} + } + } + + var dirs []Dir + if err := db.WithContext(ctx). + Select("storage_name"). + Group("storage_name"). + Find(&dirs).Error; err != nil { + return fmt.Errorf("failed to find storage dirs: %w", err) + } + for _, dir := range dirs { + if _, ok := knownStorages[dir.StorageName]; !ok && dir.StorageName != "" { + names[dir.StorageName] = struct{}{} + } + } + + var rules []Rule + if err := db.WithContext(ctx). + Select("storage_name"). + Where("storage_name <> '' AND storage_name <> ?", rule.RuleStorNameChosen). + Group("storage_name"). + Find(&rules).Error; err != nil { + return fmt.Errorf("failed to find storage rules: %w", err) + } + for _, rule := range rules { + if _, ok := knownStorages[rule.StorageName]; !ok { + names[rule.StorageName] = struct{}{} + } + } + + for name := range names { + if err := ClearStorageReferences(ctx, name); err != nil { + return fmt.Errorf("failed to clear references for storage %s: %w", name, err) + } + log.FromContext(ctx).Warnf("Cleared references to unavailable storage: %s", name) + } return nil } diff --git a/database/storage_cleanup_test.go b/database/storage_cleanup_test.go new file mode 100644 index 00000000..a1c62c29 --- /dev/null +++ b/database/storage_cleanup_test.go @@ -0,0 +1,91 @@ +package database + +import ( + "context" + "testing" + + "github.com/krau/SaveAny-Bot/pkg/rule" + "gorm.io/gorm" +) + +func TestClearStorageReferences(t *testing.T) { + ctx := context.Background() + openTestDB(t) + + user := User{ + ChatID: 1001, + DefaultStorage: "removed", + Silent: true, + } + if err := db.WithContext(ctx).Create(&user).Error; err != nil { + t.Fatalf("failed to create user: %v", err) + } + dir := Dir{UserID: user.ID, StorageName: "removed", Path: "/old"} + otherDir := Dir{UserID: user.ID, StorageName: "kept", Path: "/new"} + if err := db.WithContext(ctx).Create(&dir).Error; err != nil { + t.Fatalf("failed to create removed dir: %v", err) + } + if err := db.WithContext(ctx).Create(&otherDir).Error; err != nil { + t.Fatalf("failed to create kept dir: %v", err) + } + user.DefaultDir = dir.ID + if err := db.WithContext(ctx).Save(&user).Error; err != nil { + t.Fatalf("failed to update user default dir: %v", err) + } + if err := db.WithContext(ctx).Create(&Rule{ + UserID: user.ID, + Type: rule.FileNameRegex.String(), + Data: ".*", + StorageName: "removed", + DirPath: "/old", + }).Error; err != nil { + t.Fatalf("failed to create rule: %v", err) + } + + if err := ClearStorageReferences(ctx, "removed"); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + + var gotUser User + if err := db.WithContext(ctx).First(&gotUser, user.ID).Error; err != nil { + t.Fatalf("failed to load user: %v", err) + } + if gotUser.DefaultStorage != "" || gotUser.DefaultDir != 0 || gotUser.Silent { + t.Fatalf("expected user defaults and silent to be cleared, got storage=%q dir=%d silent=%v", gotUser.DefaultStorage, gotUser.DefaultDir, gotUser.Silent) + } + + var dirCount int64 + if err := db.WithContext(ctx).Model(&Dir{}).Where("storage_name = ?", "removed").Count(&dirCount).Error; err != nil { + t.Fatalf("failed to count removed dirs: %v", err) + } + if dirCount != 0 { + t.Fatalf("expected removed storage dirs to be deleted, got %d", dirCount) + } + if err := db.WithContext(ctx).Model(&Dir{}).Where("storage_name = ?", "kept").Count(&dirCount).Error; err != nil { + t.Fatalf("failed to count kept dirs: %v", err) + } + if dirCount != 1 { + t.Fatalf("expected kept storage dir to remain, got %d", dirCount) + } + + var gotRule Rule + if err := db.WithContext(ctx).Where("user_id = ?", user.ID).First(&gotRule).Error; err != nil { + t.Fatalf("failed to load rule: %v", err) + } + if gotRule.StorageName != rule.RuleStorNameChosen { + t.Fatalf("expected rule storage to be reset to %q, got %q", rule.RuleStorNameChosen, gotRule.StorageName) + } +} + +func openTestDB(t *testing.T) { + t.Helper() + var err error + db, err = gorm.Open(GetDialect(t.TempDir()+"/saveany.db"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open test database: %v", err) + } + if err := db.AutoMigrate(&User{}, &Dir{}, &Rule{}, &WatchChat{}); err != nil { + t.Fatalf("failed to migrate test database: %v", err) + } + t.Cleanup(func() { db = nil }) +} diff --git a/database/user.go b/database/user.go index 7f631b5e..3b33366d 100644 --- a/database/user.go +++ b/database/user.go @@ -2,7 +2,10 @@ package database import ( "context" + "fmt" + "github.com/krau/SaveAny-Bot/pkg/rule" + "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -50,3 +53,39 @@ func GetUserByID(ctx context.Context, id uint) (*User, error) { Where("id = ?", id).First(&user).Error return &user, err } + +func ClearStorageReferences(ctx context.Context, storageName string) error { + if db == nil { + return fmt.Errorf("database is not initialized") + } + if storageName == "" { + return nil + } + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&User{}). + Where("default_storage = ?", storageName). + Updates(map[string]any{ + "default_storage": "", + "default_dir": 0, + "silent": false, + }).Error; err != nil { + return err + } + if err := tx.Model(&User{}). + Where("default_dir IN (?)", tx.Model(&Dir{}).Select("id").Where("storage_name = ?", storageName)). + Update("default_dir", 0).Error; err != nil { + return err + } + if err := tx.Unscoped(). + Where("storage_name = ?", storageName). + Delete(&Dir{}).Error; err != nil { + return err + } + if err := tx.Model(&Rule{}). + Where("storage_name = ?", storageName). + Update("storage_name", rule.RuleStorNameChosen).Error; err != nil { + return err + } + return nil + }) +} From de6b6371e1432fd0bc5d7215501646065c7d0965 Mon Sep 17 00:00:00 2001 From: liyw0205 Date: Mon, 25 May 2026 12:27:29 +0800 Subject: [PATCH 2/3] feat(config): restore web editor --- api/auth.go | 58 +- api/config_editor.go | 653 ++++++++++++++++ api/config_server.go | 71 ++ api/config_web.go | 6 + api/config_web.html | 1186 ++++++++++++++++++++++++++++++ api/factory.go | 103 ++- api/handlers.go | 178 ++++- api/progress.go | 95 ++- api/runtime.go | 165 +++++ api/server.go | 19 +- api/task_progress_adapters.go | 259 +++++++ api/types.go | 43 +- cmd/run.go | 18 +- cmd/web.go | 67 ++ common/logbuffer/logbuffer.go | 79 ++ config/edit.go | 743 +++++++++++++++++++ config/viper.go | 2 +- core/tasks/transfer/execute.go | 140 +++- core/tasks/transfer/task.go | 74 ++ core/tasks/transfer/taskinfo.go | 43 +- database/db.go | 26 +- database/storage_cleanup_test.go | 59 ++ database/user.go | 27 + storage/load.go | 76 +- 24 files changed, 4047 insertions(+), 143 deletions(-) create mode 100644 api/config_editor.go create mode 100644 api/config_server.go create mode 100644 api/config_web.go create mode 100644 api/config_web.html create mode 100644 api/runtime.go create mode 100644 api/task_progress_adapters.go create mode 100644 cmd/web.go create mode 100644 common/logbuffer/logbuffer.go create mode 100644 config/edit.go diff --git a/api/auth.go b/api/auth.go index 236bec3c..2a15d5a8 100644 --- a/api/auth.go +++ b/api/auth.go @@ -14,35 +14,63 @@ type tokenContextKey struct{} // AuthMiddleware 返回认证中间件 func AuthMiddleware() func(http.Handler) http.Handler { + return TokenAuthMiddleware(func() string { + return config.C().API.Token + }) +} + +func TokenAuthMiddleware(tokenProvider func() string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cfg := config.C().API - - // 从请求头获取 token - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header") + if isPublicConfigWebPath(r) { + next.ServeHTTP(w, r) return } - // 提取 Bearer token - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { - WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid authorization header format") + token := tokenProvider() + if token == "" { + next.ServeHTTP(w, r) return } - token := parts[1] + requestToken := getRequestToken(r) + if requestToken == "" { + WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header") + return + } - // 验证 token - if subtle.ConstantTimeCompare([]byte(token), []byte(cfg.Token)) != 1 { + if subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) != 1 { WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid token") return } - // 将 token 添加到 context - ctx := context.WithValue(r.Context(), tokenContextKey{}, token) + ctx := context.WithValue(r.Context(), tokenContextKey{}, requestToken) next.ServeHTTP(w, r.WithContext(ctx)) }) } } + +func getRequestToken(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return "" + } + return parts[1] + } + if token := r.Header.Get("X-API-Token"); token != "" { + return token + } + if cookie, err := r.Cookie("saveany_api_token"); err == nil { + return cookie.Value + } + return r.URL.Query().Get("token") +} + +func isPublicConfigWebPath(r *http.Request) bool { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return false + } + return r.URL.Path == "/config" || strings.HasPrefix(r.URL.Path, "/config/") +} diff --git a/api/config_editor.go b/api/config_editor.go new file mode 100644 index 00000000..19832089 --- /dev/null +++ b/api/config_editor.go @@ -0,0 +1,653 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/logbuffer" + "github.com/krau/SaveAny-Bot/common/utils/netutil" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/pkg/rule" +) + +type ConfigEditor struct { + ctx context.Context + configPath string + autoOpenDatabase bool + mu sync.Mutex +} + +type ConfigEditorOption func(*ConfigEditor) + +func WithConfigEditorAutoOpenDatabase(autoOpen bool) ConfigEditorOption { + return func(editor *ConfigEditor) { + editor.autoOpenDatabase = autoOpen + } +} + +func RegisterConfigEditorRoutes(ctx context.Context, mux *http.ServeMux, configPath string, opts ...ConfigEditorOption) *ConfigEditor { + editor := &ConfigEditor{ + ctx: ctx, + configPath: config.ResolveConfigFilePath(configPath), + } + for _, opt := range opts { + opt(editor) + } + mux.HandleFunc("/config", editor.WebHandler) + mux.HandleFunc("/config/", editor.WebHandler) + mux.HandleFunc("/api/v1/config", editor.ConfigHandler) + mux.HandleFunc("/api/v1/config/apply", editor.ApplyConfigHandler) + mux.HandleFunc("/api/v1/config/schema", editor.SchemaHandler) + mux.HandleFunc("/api/v1/config/logs", editor.LogsHandler) + mux.HandleFunc("/api/v1/config/proxy-test", editor.ProxyTestHandler) + mux.HandleFunc("/api/v1/config/update-check", editor.UpdateCheckHandler) + mux.HandleFunc("/api/v1/config/rules/apply", editor.ApplyRuleHandler) + mux.HandleFunc("/api/v1/config/rules", editor.RulesHandler) + mux.HandleFunc("/api/v1/config/rules/", editor.RuleByIDHandler) + return editor +} + +func (e *ConfigEditor) WebHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + MethodNotAllowedHandler(w, r) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(configWebHTML) +} + +func (e *ConfigEditor) ConfigHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + file, err := config.LoadEditableConfig(e.configPath) + if err != nil { + WriteError(w, http.StatusInternalServerError, "config_load_failed", err.Error()) + return + } + WriteJSON(w, http.StatusOK, map[string]any{ + "path": file.Path, + "exists": file.Exists, + "config": file.Config, + "message": "保存后建议重启 bot 以完整应用存储、代理和高级配置。", + }) + case http.MethodPut: + cfg, err := decodeEditableConfigRequest(r) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + e.mu.Lock() + defer e.mu.Unlock() + storageChanges := e.storageChanges(cfg) + if err := config.SaveEditableConfig(e.configPath, cfg); err != nil { + WriteError(w, http.StatusBadRequest, "config_save_failed", err.Error()) + return + } + warnings := e.applyStorageReferenceChanges(r.Context(), storageChanges) + WriteJSON(w, http.StatusOK, map[string]any{ + "message": "config saved", + "path": e.configPath, + "warnings": warnings, + }) + default: + MethodNotAllowedHandler(w, r) + } +} + +func (e *ConfigEditor) ApplyConfigHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + MethodNotAllowedHandler(w, r) + return + } + before := config.C() + e.mu.Lock() + defer e.mu.Unlock() + applyResult, err := e.reloadRuntimeConfig(r.Context(), before) + if err != nil { + WriteJSON(w, http.StatusAccepted, map[string]any{ + "message": "runtime reload failed; restart the bot after fixing the config", + "path": e.configPath, + "reload_error": err.Error(), + "runtime": applyResult, + }) + return + } + WriteJSON(w, http.StatusOK, map[string]any{ + "message": "config applied", + "path": e.configPath, + "runtime": applyResult, + }) +} + +func (e *ConfigEditor) LogsHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + MethodNotAllowedHandler(w, r) + return + } + limit := 200 + if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" { + parsed, err := strconv.Atoi(raw) + if err != nil || parsed < 1 { + WriteError(w, http.StatusBadRequest, "invalid_request", "limit must be a positive integer") + return + } + limit = parsed + } + WriteJSON(w, http.StatusOK, map[string]any{ + "lines": logbuffer.Default().Lines(limit), + }) +} + +func (e *ConfigEditor) UpdateCheckHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + MethodNotAllowedHandler(w, r) + return + } + resp, err := CheckUpdate() + if err != nil { + WriteError(w, http.StatusBadGateway, "update_check_failed", err.Error()) + return + } + WriteJSON(w, http.StatusOK, resp) +} + +func (e *ConfigEditor) ProxyTestHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + MethodNotAllowedHandler(w, r) + return + } + var req ProxyTestRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error()) + return + } + req.URL = strings.TrimSpace(req.URL) + if req.URL == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "proxy url is required") + return + } + target := strings.TrimSpace(req.Target) + if target == "" { + target = "https://api.telegram.org" + } + client, err := netutil.NewProxyHTTPClient(req.URL) + if err != nil { + WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: false, Message: err.Error(), Target: target}) + return + } + client.Timeout = 10 * time.Second + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + start := time.Now() + resp, err := client.Do(httpReq) + elapsed := time.Since(start).Milliseconds() + if err != nil { + WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: false, MS: elapsed, Message: err.Error(), Target: target}) + return + } + defer resp.Body.Close() + ok := resp.StatusCode >= 200 && resp.StatusCode < 500 + message := resp.Status + if ok { + message = "proxy reachable" + } + WriteJSON(w, http.StatusOK, ProxyTestResponse{OK: ok, MS: elapsed, Message: message, Target: target}) +} + +func (e *ConfigEditor) SchemaHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + MethodNotAllowedHandler(w, r) + return + } + ruleTypes := make([]string, 0, len(rule.Values())) + for _, ruleType := range rule.Values() { + ruleTypes = append(ruleTypes, ruleType.String()) + } + WriteJSON(w, http.StatusOK, map[string]any{ + "storage_types": config.StorageTypeNames(), + "storage_schemas": config.StorageSchemas(), + "rule_types": ruleTypes, + "rule_storage_chosen": rule.RuleStorNameChosen, + "rule_dir_new_album": rule.RuleDirPathNewForAlbum, + "database_ready": database.Ready(), + "config_path": e.configPath, + "config_reload_notice": "保存配置后建议重启 bot。", + }) +} + +func (e *ConfigEditor) RulesHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + if err := e.ensureDatabase(r.Context()); err != nil { + WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error()) + return + } + users, err := database.GetAllUsers(r.Context()) + if err != nil { + WriteError(w, http.StatusInternalServerError, "users_load_failed", err.Error()) + return + } + chatIDFilter, hasFilter, err := optionalChatID(r) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + respUsers := make([]configUserResponse, 0, len(users)) + for _, user := range users { + if hasFilter && user.ChatID != chatIDFilter { + continue + } + respUsers = append(respUsers, convertConfigUser(user)) + } + WriteJSON(w, http.StatusOK, map[string]any{ + "users": respUsers, + "storages": configuredStorageNames(), + }) + case http.MethodPost: + if err := e.ensureDatabase(r.Context()); err != nil { + WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error()) + return + } + var req createRuleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error()) + return + } + normalized, err := normalizeCreateRuleRequest(req) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_rule", err.Error()) + return + } + user, err := database.GetUserByChatID(r.Context(), normalized.ChatID) + if err != nil { + WriteError(w, http.StatusNotFound, "user_not_found", fmt.Sprintf("user %d is not in database; save config and reload first", normalized.ChatID)) + return + } + newRule := &database.Rule{ + UserID: user.ID, + Type: normalized.Type, + Data: normalized.Data, + StorageName: normalized.StorageName, + DirPath: normalized.DirPath, + } + if err := database.CreateRule(r.Context(), newRule); err != nil { + WriteError(w, http.StatusInternalServerError, "rule_create_failed", err.Error()) + return + } + WriteJSON(w, http.StatusCreated, convertConfigRule(*newRule)) + default: + MethodNotAllowedHandler(w, r) + } +} + +func (e *ConfigEditor) RuleByIDHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + MethodNotAllowedHandler(w, r) + return + } + if err := e.ensureDatabase(r.Context()); err != nil { + WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error()) + return + } + id, err := ruleIDFromPath(r.URL.Path) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + if err := database.DeleteRule(r.Context(), id); err != nil { + WriteError(w, http.StatusInternalServerError, "rule_delete_failed", err.Error()) + return + } + WriteJSON(w, http.StatusOK, map[string]any{"message": "rule deleted"}) +} + +func (e *ConfigEditor) ApplyRuleHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut && r.Method != http.MethodPatch { + MethodNotAllowedHandler(w, r) + return + } + if err := e.ensureDatabase(r.Context()); err != nil { + WriteError(w, http.StatusServiceUnavailable, "database_unavailable", err.Error()) + return + } + var req updateRuleModeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error()) + return + } + if req.ChatID == 0 { + WriteError(w, http.StatusBadRequest, "invalid_request", "chat_id is required") + return + } + if err := database.UpdateUserApplyRule(r.Context(), req.ChatID, req.ApplyRule); err != nil { + WriteError(w, http.StatusInternalServerError, "rule_mode_update_failed", err.Error()) + return + } + WriteJSON(w, http.StatusOK, map[string]any{ + "chat_id": req.ChatID, + "apply_rule": req.ApplyRule, + }) +} + +func (e *ConfigEditor) reloadRuntimeConfig(ctx context.Context, before config.Config) (RuntimeApplyResult, error) { + if err := config.Init(ctx, e.configPath); err != nil { + return RuntimeApplyResult{}, fmt.Errorf("saved config but failed to reload runtime config: %w", err) + } + result := RuntimeApplyResult{} + if e.autoOpenDatabase { + if !database.Ready() { + if err := database.Open(ctx); err != nil { + result.Warnings = append(result.Warnings, "数据库打开失败: "+err.Error()) + } + } + } + result.Warnings = append(result.Warnings, e.cleanupRemovedStorageReferences(ctx, removedRuntimeStorageNames(before, config.C()))...) + applied := ApplyRuntimeConfig(ctx, e.configPath, before) + applied.Warnings = append(result.Warnings, applied.Warnings...) + return applied, nil +} + +type storageReferenceChanges struct { + Renamed map[string]string + Removed []string +} + +func (e *ConfigEditor) storageChanges(next *config.EditableConfig) storageReferenceChanges { + current, err := config.LoadEditableConfig(e.configPath) + if err != nil { + return storageReferenceChanges{} + } + return diffEditableStorages(current.Config.Storages, next.Storages) +} + +func (e *ConfigEditor) applyStorageReferenceChanges(ctx context.Context, changes storageReferenceChanges) []string { + warnings := make([]string, 0) + if database.Ready() { + for oldName, newName := range changes.Renamed { + if err := database.RenameStorageReferences(ctx, oldName, newName); err != nil { + warnings = append(warnings, fmt.Sprintf("迁移存储 %s 到 %s 的用户关联失败: %v", oldName, newName, err)) + } + } + } + warnings = append(warnings, e.cleanupRemovedStorageReferences(ctx, changes.Removed)...) + return warnings +} + +func (e *ConfigEditor) cleanupRemovedStorageReferences(ctx context.Context, names []string) []string { + names = compactStrings(names) + if len(names) == 0 || !database.Ready() { + return nil + } + warnings := make([]string, 0) + for _, name := range names { + if err := database.ClearStorageReferences(ctx, name); err != nil { + warnings = append(warnings, fmt.Sprintf("清理存储 %s 的用户关联失败: %v", name, err)) + } + } + return warnings +} + +func diffEditableStorages(before, after []map[string]any) storageReferenceChanges { + beforeNames := editableStorageNames(before) + afterNames := editableStorageNames(after) + changes := storageReferenceChanges{ + Renamed: make(map[string]string), + Removed: make([]string, 0), + } + for i := range before { + if i >= len(after) { + continue + } + oldName := strings.TrimSpace(fmt.Sprint(before[i]["name"])) + newName := strings.TrimSpace(fmt.Sprint(after[i]["name"])) + if oldName == "" || newName == "" || oldName == newName { + continue + } + if _, oldStillExists := afterNames[oldName]; oldStillExists { + continue + } + if _, newExistedBefore := beforeNames[newName]; newExistedBefore { + continue + } + changes.Renamed[oldName] = newName + } + for oldName := range beforeNames { + if _, ok := afterNames[oldName]; ok { + continue + } + if _, renamed := changes.Renamed[oldName]; renamed { + continue + } + changes.Removed = append(changes.Removed, oldName) + } + return changes +} + +func removedRuntimeStorageNames(before, after config.Config) []string { + beforeNames := make(map[string]struct{}, len(before.Storages)) + afterNames := make(map[string]struct{}, len(after.Storages)) + for _, storage := range before.Storages { + beforeNames[storage.GetName()] = struct{}{} + } + for _, storage := range after.Storages { + afterNames[storage.GetName()] = struct{}{} + } + return missingStorageNames(beforeNames, afterNames) +} + +func editableStorageNames(storages []map[string]any) map[string]struct{} { + names := make(map[string]struct{}, len(storages)) + for _, storage := range storages { + name := strings.TrimSpace(fmt.Sprint(storage["name"])) + if name != "" { + names[name] = struct{}{} + } + } + return names +} + +func missingStorageNames(before, after map[string]struct{}) []string { + missing := make([]string, 0) + for name := range before { + if _, ok := after[name]; !ok { + missing = append(missing, name) + } + } + return missing +} + +func (e *ConfigEditor) ensureDatabase(ctx context.Context) error { + if database.Ready() { + return nil + } + if !e.autoOpenDatabase { + return fmt.Errorf("database is not initialized") + } + if err := config.Init(ctx, e.configPath); err != nil { + return fmt.Errorf("load config before opening database: %w", err) + } + return database.Open(ctx) +} + +func decodeEditableConfigRequest(r *http.Request) (*config.EditableConfig, error) { + var raw json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&raw); err != nil { + return nil, fmt.Errorf("failed to decode request body: %w", err) + } + var wrapped struct { + Config *config.EditableConfig `json:"config"` + } + if err := json.Unmarshal(raw, &wrapped); err == nil && wrapped.Config != nil { + return wrapped.Config, nil + } + var cfg config.EditableConfig + if err := json.Unmarshal(raw, &cfg); err != nil { + return nil, fmt.Errorf("failed to decode config: %w", err) + } + return &cfg, nil +} + +type configUserResponse struct { + ID uint `json:"id"` + ChatID int64 `json:"chat_id"` + ApplyRule bool `json:"apply_rule"` + DefaultStorage string `json:"default_storage"` + DefaultDir uint `json:"default_dir"` + Rules []configRuleResponse `json:"rules"` +} + +type configRuleResponse struct { + ID uint `json:"id"` + Type string `json:"type"` + Data string `json:"data"` + StorageName string `json:"storage_name"` + DirPath string `json:"dir_path"` +} + +type createRuleRequest struct { + ChatID int64 `json:"chat_id"` + Type string `json:"type"` + Data string `json:"data"` + StorageName string `json:"storage_name"` + DirPath string `json:"dir_path"` +} + +type updateRuleModeRequest struct { + ChatID int64 `json:"chat_id"` + ApplyRule bool `json:"apply_rule"` +} + +func convertConfigUser(user database.User) configUserResponse { + rules := make([]configRuleResponse, 0, len(user.Rules)) + for _, userRule := range user.Rules { + rules = append(rules, convertConfigRule(userRule)) + } + return configUserResponse{ + ID: user.ID, + ChatID: user.ChatID, + ApplyRule: user.ApplyRule, + DefaultStorage: user.DefaultStorage, + DefaultDir: user.DefaultDir, + Rules: rules, + } +} + +func convertConfigRule(userRule database.Rule) configRuleResponse { + return configRuleResponse{ + ID: userRule.ID, + Type: userRule.Type, + Data: userRule.Data, + StorageName: userRule.StorageName, + DirPath: userRule.DirPath, + } +} + +func optionalChatID(r *http.Request) (int64, bool, error) { + raw := strings.TrimSpace(r.URL.Query().Get("chat_id")) + if raw == "" { + return 0, false, nil + } + chatID, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return 0, false, fmt.Errorf("invalid chat_id: %w", err) + } + return chatID, true, nil +} + +func ruleIDFromPath(path string) (uint, error) { + raw := strings.TrimPrefix(path, "/api/v1/config/rules/") + if raw == "" || strings.Contains(raw, "/") { + return 0, fmt.Errorf("rule id is required") + } + id, err := strconv.ParseUint(raw, 10, 64) + if err != nil || id == 0 { + return 0, fmt.Errorf("invalid rule id") + } + return uint(id), nil +} + +func normalizeCreateRuleRequest(req createRuleRequest) (createRuleRequest, error) { + req.Type = strings.ToUpper(strings.TrimSpace(req.Type)) + req.Data = strings.TrimSpace(req.Data) + req.StorageName = strings.TrimSpace(req.StorageName) + req.DirPath = strings.TrimSpace(req.DirPath) + if req.ChatID == 0 { + return req, fmt.Errorf("chat_id is required") + } + if req.Type == "" { + return req, fmt.Errorf("type is required") + } + if req.Data == "" { + return req, fmt.Errorf("data is required") + } + if req.StorageName == "" { + return req, fmt.Errorf("storage_name is required") + } + if req.DirPath == "" { + return req, fmt.Errorf("dir_path is required") + } + validType := false + for _, value := range rule.Values() { + if req.Type == value.String() { + validType = true + break + } + } + if !validType { + return req, fmt.Errorf("invalid rule type: %s", req.Type) + } + switch req.Type { + case rule.FileNameRegex.String(), rule.MessageRegex.String(): + if _, err := regexp.Compile(req.Data); err != nil { + return req, fmt.Errorf("invalid regex: %w", err) + } + case rule.IsAlbum.String(): + if _, err := strconv.ParseBool(req.Data); err != nil { + return req, fmt.Errorf("IS-ALBUM data must be true or false") + } + } + if req.StorageName != rule.RuleStorNameChosen && !configuredStorageNameExists(req.StorageName) { + return req, fmt.Errorf("unknown storage_name: %s", req.StorageName) + } + return req, nil +} + +func configuredStorageNames() []string { + cfg, err := config.LoadEditableConfig(config.ConfigFileUsed()) + if err != nil { + log.Warnf("failed to load configured storage names: %v", err) + return nil + } + names := make([]string, 0, len(cfg.Config.Storages)) + for _, storage := range cfg.Config.Storages { + name := strings.TrimSpace(fmt.Sprint(storage["name"])) + if name != "" { + names = append(names, name) + } + } + return names +} + +func configuredStorageNameExists(name string) bool { + for _, configuredName := range configuredStorageNames() { + if configuredName == name { + return true + } + } + return false +} diff --git a/api/config_server.go b/api/config_server.go new file mode 100644 index 00000000..ea25d20b --- /dev/null +++ b/api/config_server.go @@ -0,0 +1,71 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" +) + +type ConfigWebServerOptions struct { + ConfigPath string + Host string + Port int + Token string +} + +func NewConfigWebServer(ctx context.Context, opts ConfigWebServerOptions) *http.Server { + mux := http.NewServeMux() + RegisterConfigEditorRoutes(ctx, mux, opts.ConfigPath, WithConfigEditorAutoOpenDatabase(true)) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + http.Redirect(w, r, "/config", http.StatusFound) + return + } + NotFoundHandler(w, r) + }) + + var handler http.Handler = mux + handler = TokenAuthMiddleware(func() string { + if opts.Token != "" { + return opts.Token + } + return config.C().API.Token + })(handler) + handler = loggingMiddleware(handler) + handler = recoveryMiddleware(handler) + + return &http.Server{ + Addr: fmt.Sprintf("%s:%d", opts.Host, opts.Port), + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } +} + +func StartConfigWebServer(ctx context.Context, opts ConfigWebServerOptions) (*http.Server, error) { + server := NewConfigWebServer(ctx, opts) + logger := log.FromContext(ctx).With("module", "config-web") + logger.Infof("Starting config web server on %s", server.Addr) + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Errorf("Config web server error: %v", err) + } + }() + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + logger.Errorf("Config web server shutdown error: %v", err) + } + }() + return server, nil +} diff --git a/api/config_web.go b/api/config_web.go new file mode 100644 index 00000000..e296bb7e --- /dev/null +++ b/api/config_web.go @@ -0,0 +1,6 @@ +package api + +import _ "embed" + +//go:embed config_web.html +var configWebHTML []byte diff --git a/api/config_web.html b/api/config_web.html new file mode 100644 index 00000000..47658163 --- /dev/null +++ b/api/config_web.html @@ -0,0 +1,1186 @@ + + + + + + SaveAny-Bot 配置 + + + + + +
+ +
+
+
+

配置

+
config.toml
+
+
+ + +
+
+
+
+

验证

+
+
+ + +
+
+
+
+
+
+
+
+
+
+
+ + + + diff --git a/api/factory.go b/api/factory.go index 4c00a6bd..51b9c84e 100644 --- a/api/factory.go +++ b/api/factory.go @@ -38,8 +38,8 @@ func NewTaskFactory(ctx context.Context) *TaskFactory { // CreateTask 创建任务 func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) { // 验证存储 - stor, ok := storage.Storages[req.Storage] - if !ok { + stor, err := storage.GetStorageByName(f.ctx, req.Storage) + if err != nil { return nil, fmt.Errorf("storage not found: %s", req.Storage) } @@ -66,19 +66,6 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e } } -func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error { - taskID := task.TaskID() - RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook) - - err := core.AddTask(f.ctx, NewExecutableWrapper(task)) - if err != nil { - DeleteTask(taskID) - return fmt.Errorf("failed to add task: %w", err) - } - - return nil -} - // createDirectLinksTask 创建直链下载任务 func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { var params DirectLinksParams @@ -90,11 +77,12 @@ func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, return nil, fmt.Errorf("no URLs provided") } - task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil) + task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, newDirectLinksAPIProgress(taskID)) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) - err := f.registerAndEnqueueTask(task, tasktype.TaskTypeDirectlinks, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ @@ -116,11 +104,12 @@ func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *C return nil, fmt.Errorf("no URLs provided") } - task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil) + task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, newYTDLPAPIProgress(taskID)) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) - err := f.registerAndEnqueueTask(task, tasktype.TaskTypeYtdlp, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ @@ -159,11 +148,12 @@ func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *C return nil, fmt.Errorf("failed to add aria2 task: %w", err) } - task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil) + task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, newAria2APIProgress(taskID)) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) - err = f.registerAndEnqueueTask(task, tasktype.TaskTypeAria2, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ @@ -204,11 +194,12 @@ func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req * return nil, fmt.Errorf("failed to parse URL: %w", err) } - task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil) + task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, newParsedAPIProgress(taskID)) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) - err = f.registerAndEnqueueTask(task, tasktype.TaskTypeParseditem, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ @@ -240,15 +231,17 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req return nil, fmt.Errorf("no files found in provided links") } - var task core.Executable - if len(files) == 1 { // 单个文件任务 - tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil) + tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, newTFileAPIProgress(taskID)) if err != nil { return nil, fmt.Errorf("failed to create tfile task: %w", err) } - task = tfileTask + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, tfileTask.Title(), req.Webhook, req) + if err := core.AddTask(f.ctx, tfileTask); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) + } } else { // 批量文件任务 elems := make([]batchtfile.TaskElement, 0, len(files)) @@ -260,12 +253,12 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req elems = append(elems, *elem) } - task = batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true) - } - - err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTgfiles, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, newBatchTFileAPIProgress(taskID), true) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) + } } return &CreateTaskResponse{ @@ -298,11 +291,12 @@ func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req } client := telegraph.NewClient() - task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil) + task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, newTelegraphAPIProgress(taskID)) + RegisterTask(taskID, req.Type.String(), req.Storage, req.Path, task.Title(), req.Webhook, req) - err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTphpics, req.Storage, req.Path, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ @@ -321,13 +315,13 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req } // 验证源存储和目标存储 - sourceStor, ok := storage.Storages[params.SourceStorage] - if !ok { + sourceStor, err := storage.GetStorageByName(f.ctx, params.SourceStorage) + if err != nil { return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage) } - targetStor, ok := storage.Storages[params.TargetStorage] - if !ok { + targetStor, err := storage.GetStorageByName(f.ctx, params.TargetStorage) + if err != nil { return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage) } @@ -360,11 +354,14 @@ func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req elems = append(elems, *elem) } - task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true) + task := transfer.NewTransferTask(taskID, f.ctx, elems, newTransferAPIProgress(taskID), true) + info := RegisterTask(taskID, req.Type.String(), params.TargetStorage, params.TargetPath, task.Title(), req.Webhook, req) + info.SetTransferMeta(params.SourceStorage, params.SourcePath, params.TargetStorage, params.TargetPath) + RegisterTaskControl(taskID, task) - err = f.registerAndEnqueueTask(task, tasktype.TaskTypeTransfer, params.TargetStorage, params.TargetPath, req.Webhook) - if err != nil { - return nil, err + if err := core.AddTask(f.ctx, task); err != nil { + DeleteTask(taskID) + return nil, fmt.Errorf("failed to add task: %w", err) } return &CreateTaskResponse{ diff --git a/api/handlers.go b/api/handlers.go index 81946d02..db321d7a 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "strings" + "sync/atomic" "github.com/krau/SaveAny-Bot/core" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" @@ -117,6 +118,15 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) { return } + if r.URL.Query().Get("record") == "1" || r.URL.Query().Get("record") == "true" { + if task.Status == TaskStatusQueued || task.Status == TaskStatusRunning { + _ = core.CancelTask(r.Context(), taskID) + } + DeleteTask(taskID) + WriteJSON(w, http.StatusOK, map[string]string{"message": "task deleted successfully"}) + return + } + // 取消任务 if err := core.CancelTask(r.Context(), taskID); err != nil { WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error()) @@ -127,6 +137,106 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"}) } +func (h *Handlers) PauseTaskHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + MethodNotAllowedHandler(w, r) + return + } + taskID, _ := extractTaskIDAndAction(r.URL.Path) + if taskID == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required") + return + } + task, ok := GetTask(taskID) + if !ok { + WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID) + return + } + if task.Status == TaskStatusQueued || task.Status == TaskStatusRunning { + if err := core.CancelTask(r.Context(), taskID); err != nil { + WriteError(w, http.StatusInternalServerError, "pause_failed", "failed to pause task: "+err.Error()) + return + } + } + task.UpdateStatus(TaskStatusPaused) + task.UpdatePhase("paused") + WriteJSON(w, http.StatusOK, map[string]string{"message": "task paused successfully"}) +} + +func (h *Handlers) RetryTaskHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + MethodNotAllowedHandler(w, r) + return + } + taskID, _ := extractTaskIDAndAction(r.URL.Path) + if taskID == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required") + return + } + task, ok := GetTask(taskID) + if !ok { + WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID) + return + } + if task.Request == nil { + WriteError(w, http.StatusBadRequest, "retry_unavailable", "task request data is not available") + return + } + if task.Status != TaskStatusFailed && task.Status != TaskStatusCancelled && task.Status != TaskStatusPaused { + WriteError(w, http.StatusBadRequest, "retry_unavailable", "only failed, cancelled, or paused tasks can be retried") + return + } + resp, err := h.factory.CreateTask(cloneCreateTaskRequest(task.Request)) + if err != nil { + WriteError(w, http.StatusBadRequest, "retry_failed", err.Error()) + return + } + WriteJSON(w, http.StatusCreated, resp) +} + +func (h *Handlers) UpdateTaskPathHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut && r.Method != http.MethodPatch { + MethodNotAllowedHandler(w, r) + return + } + taskID, _ := extractTaskIDAndAction(r.URL.Path) + if taskID == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required") + return + } + task, ok := GetTask(taskID) + if !ok { + WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID) + return + } + if task.Type != string(tasktype.TaskTypeTransfer) { + WriteError(w, http.StatusBadRequest, "unsupported_task", "only transfer task path can be updated") + return + } + var req UpdateTaskPathRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error()) + return + } + req.Path = strings.TrimSpace(req.Path) + if req.Path == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "path is required") + return + } + control, ok := GetTaskControl(taskID) + if !ok { + WriteError(w, http.StatusConflict, "control_unavailable", "task runtime control is not available") + return + } + control.UpdateTargetPath(req.Path) + task.UpdateTargetPath(req.Path) + updateStoredTransferTargetPath(task, req.Path) + WriteJSON(w, http.StatusOK, map[string]any{ + "message": "task path updated", + "target_path": req.Path, + }) +} + // ListStoragesHandler 列出存储处理器 func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { @@ -134,8 +244,9 @@ func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) { return } - storages := make([]StorageInfo, 0, len(storage.Storages)) - for name, stor := range storage.Storages { + snapshot := storage.Snapshot() + storages := make([]StorageInfo, 0, len(snapshot)) + for name, stor := range snapshot { storages = append(storages, StorageInfo{ Name: name, Type: string(stor.Type()), @@ -177,33 +288,54 @@ func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) { // extractTaskIDFromPath 从路径中提取任务 ID // 路径格式: /api/v1/tasks/:id func extractTaskIDFromPath(path string) string { + taskID, _ := extractTaskIDAndAction(path) + return taskID +} + +func extractTaskIDAndAction(path string) (string, string) { parts := strings.Split(strings.Trim(path, "/"), "/") if len(parts) < 4 { - return "" + return "", "" + } + action := "" + if len(parts) > 4 { + action = parts[4] } - return parts[3] + return parts[3], action } // convertTaskProgressToResponse 将任务进度转换为响应格式 func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse { resp := TaskInfoResponse{ - TaskID: task.TaskID, - Type: tasktype.TaskType(task.Type), - Status: task.Status, - Title: task.Title, - Storage: task.Storage, - Path: task.Path, - Error: task.Error, - CreatedAt: task.CreatedAt, - UpdatedAt: task.UpdatedAt, + TaskID: task.TaskID, + Type: tasktype.TaskType(task.Type), + Status: task.Status, + Title: task.Title, + Storage: task.Storage, + Path: task.Path, + SourceStorage: task.SourceStorage, + SourcePath: task.SourcePath, + TargetStorage: task.TargetStorage, + TargetPath: task.TargetPath, + Phase: task.Phase, + Error: task.Error, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, } // 计算进度 if task.TotalBytes > 0 { - percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes) + downloaded := atomic.LoadInt64(&task.DownloadedBytes) + uploaded := atomic.LoadInt64(&task.UploadedBytes) + current := downloaded + if uploaded > 0 || task.Phase == "uploading" { + current = uploaded + } + percent := float64(current) * 100 / float64(task.TotalBytes) resp.Progress = &TaskProgress{ TotalBytes: task.TotalBytes, - DownloadedBytes: task.DownloadedBytes, + DownloadedBytes: downloaded, + UploadedBytes: uploaded, Percent: percent, } } @@ -211,6 +343,22 @@ func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse { return resp } +func updateStoredTransferTargetPath(info *TaskProgressInfo, targetPath string) { + if info.Request == nil || info.Request.Params == nil { + return + } + var params TransferParams + if err := json.Unmarshal(info.Request.Params, ¶ms); err != nil { + return + } + params.TargetPath = targetPath + data, err := json.Marshal(params) + if err != nil { + return + } + info.Request.Params = data +} + // NotFoundHandler 404 处理器 func NotFoundHandler(w http.ResponseWriter, r *http.Request) { WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path) diff --git a/api/progress.go b/api/progress.go index adac3712..71d103a9 100644 --- a/api/progress.go +++ b/api/progress.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "sync" "sync/atomic" "time" @@ -14,28 +15,41 @@ type TaskProgressInfo struct { Title string TotalBytes int64 DownloadedBytes int64 + UploadedBytes int64 TotalFiles int DownloadedFiles int Storage string Path string + SourceStorage string + SourcePath string + TargetStorage string + TargetPath string + Phase string Error string CreatedAt time.Time UpdatedAt time.Time Webhook string + Request *CreateTaskRequest } // progressStore 存储所有 API 任务的进度信息 type progressStore struct { - mu sync.RWMutex - tasks map[string]*TaskProgressInfo + mu sync.RWMutex + tasks map[string]*TaskProgressInfo + controls map[string]TaskControl } var store = &progressStore{ - tasks: make(map[string]*TaskProgressInfo), + tasks: make(map[string]*TaskProgressInfo), + controls: make(map[string]TaskControl), +} + +type TaskControl interface { + UpdateTargetPath(path string) } // RegisterTask 注册一个新的 API 任务 -func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo { +func RegisterTask(taskID, taskType, storage, path, title, webhook string, reqs ...*CreateTaskRequest) *TaskProgressInfo { info := &TaskProgressInfo{ TaskID: taskID, Type: taskType, @@ -43,10 +57,14 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP Title: title, Storage: storage, Path: path, + Phase: "queued", CreatedAt: time.Now(), UpdatedAt: time.Now(), Webhook: webhook, } + if len(reqs) > 0 && reqs[0] != nil { + info.Request = cloneCreateTaskRequest(reqs[0]) + } store.mu.Lock() store.tasks[taskID] = info @@ -80,6 +98,23 @@ func DeleteTask(taskID string) { store.mu.Lock() defer store.mu.Unlock() delete(store.tasks, taskID) + delete(store.controls, taskID) +} + +func RegisterTaskControl(taskID string, control TaskControl) { + if control == nil { + return + } + store.mu.Lock() + defer store.mu.Unlock() + store.controls[taskID] = control +} + +func GetTaskControl(taskID string) (TaskControl, bool) { + store.mu.RLock() + defer store.mu.RUnlock() + control, ok := store.controls[taskID] + return control, ok } // UpdateStatus 更新任务状态 @@ -92,6 +127,44 @@ func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) { func (t *TaskProgressInfo) SetError(err string) { t.Error = err t.Status = TaskStatusFailed + t.Phase = "failed" + t.UpdatedAt = time.Now() +} + +func (t *TaskProgressInfo) SetTransferMeta(sourceStorage, sourcePath, targetStorage, targetPath string) { + t.SourceStorage = sourceStorage + t.SourcePath = sourcePath + t.TargetStorage = targetStorage + t.TargetPath = targetPath + t.Path = targetPath + t.UpdatedAt = time.Now() +} + +func (t *TaskProgressInfo) UpdatePhase(phase string) { + t.Phase = phase + t.UpdatedAt = time.Now() +} + +func (t *TaskProgressInfo) UpdateDownloadProgress(downloadedBytes, totalBytes int64) { + if totalBytes > 0 { + t.TotalBytes = totalBytes + } + atomic.StoreInt64(&t.DownloadedBytes, downloadedBytes) + t.UpdatedAt = time.Now() +} + +func (t *TaskProgressInfo) UpdateUploadProgress(uploadedBytes, totalBytes int64) { + if totalBytes > 0 { + t.TotalBytes = totalBytes + } + atomic.StoreInt64(&t.UploadedBytes, uploadedBytes) + t.UpdatedAt = time.Now() +} + +func (t *TaskProgressInfo) UpdateTargetPath(path string) { + t.TargetPath = path + t.Path = path + t.UploadedBytes = 0 t.UpdatedAt = time.Now() } @@ -111,6 +184,7 @@ func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) { p.info.Status = TaskStatusRunning p.info.TotalBytes = totalBytes p.info.TotalFiles = totalFiles + p.info.Phase = "running" p.info.UpdatedAt = time.Now() } @@ -126,8 +200,10 @@ func (p *ProgressTracker) OnDone(err error) { if err != nil { p.info.Status = TaskStatusFailed p.info.Error = err.Error() + p.info.Phase = "failed" } else { p.info.Status = TaskStatusCompleted + p.info.Phase = "completed" } p.info.UpdatedAt = time.Now() } @@ -148,3 +224,14 @@ func (p *ProgressTracker) UpdateProgressFiles(files int) { p.info.DownloadedFiles = files p.info.UpdatedAt = time.Now() } + +func cloneCreateTaskRequest(req *CreateTaskRequest) *CreateTaskRequest { + if req == nil { + return nil + } + cloned := *req + if req.Params != nil { + cloned.Params = append(json.RawMessage(nil), req.Params...) + } + return &cloned +} diff --git a/api/runtime.go b/api/runtime.go new file mode 100644 index 00000000..8ce7c285 --- /dev/null +++ b/api/runtime.go @@ -0,0 +1,165 @@ +package api + +import ( + "context" + "fmt" + "runtime" + "slices" + "strings" + + "github.com/blang/semver" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/database" + "github.com/krau/SaveAny-Bot/storage" + "github.com/unvgo/ghselfupdate" +) + +type RuntimeApplyResult struct { + Applied []string `json:"applied"` + RestartRequired []string `json:"restart_required"` + Warnings []string `json:"warnings,omitempty"` +} + +type UpdateCheckResponse struct { + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version,omitempty"` + LatestName string `json:"latest_name,omitempty"` + Found bool `json:"found"` + HasUpdate bool `json:"has_update"` + ReleaseURL string `json:"release_url,omitempty"` + ReleaseNotes string `json:"release_notes,omitempty"` + Platform string `json:"platform"` + Message string `json:"message,omitempty"` +} + +func ApplyRuntimeConfig(ctx context.Context, configPath string, before config.Config) RuntimeApplyResult { + result := RuntimeApplyResult{ + Applied: []string{"配置文件"}, + } + after := config.C() + + if err := applyLogLevel(after.Log.Level); err != nil { + result.Warnings = append(result.Warnings, err.Error()) + } else { + result.Applied = append(result.Applied, "日志级别") + } + + i18n.Init(after.Lang) + result.Applied = append(result.Applied, "语言") + + if database.Ready() { + if dbPathChanged(before, after) { + result.RestartRequired = append(result.RestartRequired, "数据库路径") + } else if err := database.SyncUsers(ctx); err != nil { + result.Warnings = append(result.Warnings, "同步用户失败: "+err.Error()) + } else { + result.Applied = append(result.Applied, "用户配置") + } + } + + if err := storage.ReloadStorages(ctx); err != nil { + result.Warnings = append(result.Warnings, "部分存储加载失败: "+err.Error()) + } + result.Applied = append(result.Applied, "存储配置") + + result.Applied = append(result.Applied, "全局 HTTP 代理", "API Token") + result.RestartRequired = append(result.RestartRequired, restartRequiredChanges(before, after)...) + result.RestartRequired = compactStrings(result.RestartRequired) + result.Applied = compactStrings(result.Applied) + return result +} + +func CheckUpdate() (*UpdateCheckResponse, error) { + resp := &UpdateCheckResponse{ + CurrentVersion: config.Version, + Platform: runtime.GOOS + "/" + runtime.GOARCH, + } + current, err := semver.ParseTolerant(strings.TrimPrefix(config.Version, "v")) + if err != nil { + resp.Message = "当前版本不是发布版本,无法准确比较。" + } + + latest, found, err := ghselfupdate.DetectLatest(config.GitRepo) + if err != nil { + return nil, err + } + resp.Found = found + if !found { + resp.Message = "没有找到 release。" + return resp, nil + } + + resp.LatestVersion = latest.Version.String() + resp.LatestName = latest.Name + resp.ReleaseNotes = latest.ReleaseNotes + resp.ReleaseURL = latest.URL + if err == nil { + resp.HasUpdate = latest.Version.GT(current) + if latest.Version.Equals(current) || latest.Version.LT(current) { + resp.Message = "当前已经是最新版本。" + } + } + return resp, nil +} + +func applyLogLevel(level string) error { + parsed, err := log.ParseLevel(strings.TrimSpace(level)) + if err != nil { + return fmt.Errorf("日志级别无效: %w", err) + } + log.Default().SetLevel(parsed) + return nil +} + +func restartRequiredChanges(before, after config.Config) []string { + var fields []string + if before.Lang == "" && before.DB.Path == "" { + return fields + } + if before.Workers != 0 && before.Workers != after.Workers { + fields = append(fields, "Workers 并发数") + } + if before.API.Enable != after.API.Enable || before.API.Host != after.API.Host || before.API.Port != after.API.Port { + fields = append(fields, "HTTP API 监听地址/端口") + } + if before.Telegram.Token != after.Telegram.Token || + before.Telegram.AppID != after.Telegram.AppID || + before.Telegram.AppHash != after.Telegram.AppHash || + before.Telegram.Proxy != after.Telegram.Proxy || + before.Telegram.Userbot != after.Telegram.Userbot { + fields = append(fields, "Telegram Bot/Userbot 连接") + } + if before.Parser.PluginEnable != after.Parser.PluginEnable || !slices.Equal(before.Parser.PluginDirs, after.Parser.PluginDirs) { + fields = append(fields, "Parser 插件加载") + } + if before.Cache != after.Cache { + fields = append(fields, "缓存参数") + } + if before.Hook != after.Hook { + fields = append(fields, "任务 Hook") + } + return fields +} + +func dbPathChanged(before, after config.Config) bool { + return before.DB.Path != "" && before.DB.Path != after.DB.Path +} + +func compactStrings(values []string) []string { + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} diff --git a/api/server.go b/api/server.go index be8b6603..1fd7e6fc 100644 --- a/api/server.go +++ b/api/server.go @@ -41,10 +41,26 @@ func NewServer(ctx context.Context) *Server { } }) mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) { + _, action := extractTaskIDAndAction(r.URL.Path) + switch action { + case "pause": + handlers.PauseTaskHandler(w, r) + return + case "retry": + handlers.RetryTaskHandler(w, r) + return + case "path": + handlers.UpdateTaskPathHandler(w, r) + return + } // 根据方法和路径分发 switch r.Method { case http.MethodGet: - handlers.GetTaskHandler(w, r) + if r.URL.Path == "/api/v1/tasks" { + handlers.ListTasksHandler(w, r) + } else { + handlers.GetTaskHandler(w, r) + } case http.MethodDelete: handlers.CancelTaskHandler(w, r) default: @@ -53,6 +69,7 @@ func NewServer(ctx context.Context) *Server { }) mux.HandleFunc("/api/v1/storages", handlers.ListStoragesHandler) mux.HandleFunc("/api/v1/task-types", handlers.GetTaskTypesHandler) + RegisterConfigEditorRoutes(ctx, mux, config.ConfigFileUsed()) // 404 处理 mux.HandleFunc("/", NotFoundHandler) diff --git a/api/task_progress_adapters.go b/api/task_progress_adapters.go new file mode 100644 index 00000000..8f19ec66 --- /dev/null +++ b/api/task_progress_adapters.go @@ -0,0 +1,259 @@ +package api + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/krau/SaveAny-Bot/core/tasks/aria2dl" + "github.com/krau/SaveAny-Bot/core/tasks/batchtfile" + "github.com/krau/SaveAny-Bot/core/tasks/directlinks" + "github.com/krau/SaveAny-Bot/core/tasks/parsed" + tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph" + "github.com/krau/SaveAny-Bot/core/tasks/tfile" + "github.com/krau/SaveAny-Bot/core/tasks/transfer" + "github.com/krau/SaveAny-Bot/core/tasks/ytdlp" + "github.com/krau/SaveAny-Bot/pkg/aria2" +) + +type directLinksAPIProgress struct{ taskID string } + +func newDirectLinksAPIProgress(taskID string) directlinks.ProgressTracker { + return &directLinksAPIProgress{taskID: taskID} +} + +func (p *directLinksAPIProgress) OnStart(ctx context.Context, info directlinks.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase("downloading") + task.TotalFiles = info.TotalFiles() + } +} + +func (p *directLinksAPIProgress) OnProgress(ctx context.Context, info directlinks.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase("downloading") + task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes()) + task.TotalFiles = info.TotalFiles() + } +} + +func (p *directLinksAPIProgress) OnDone(ctx context.Context, info directlinks.TaskInfo, err error) { + finishAPITask(p.taskID, err) +} + +type ytdlpAPIProgress struct{ taskID string } + +func newYTDLPAPIProgress(taskID string) ytdlp.ProgressTracker { + return &ytdlpAPIProgress{taskID: taskID} +} + +func (p *ytdlpAPIProgress) OnStart(ctx context.Context, task *ytdlp.Task) { + if info, ok := GetTask(p.taskID); ok { + info.UpdateStatus(TaskStatusRunning) + info.UpdatePhase("downloading") + } +} + +func (p *ytdlpAPIProgress) OnProgress(ctx context.Context, task *ytdlp.Task, status string) { + if info, ok := GetTask(p.taskID); ok { + if strings.HasPrefix(strings.ToLower(status), "transferred") { + info.UpdatePhase("uploading") + } else { + info.UpdatePhase("downloading") + } + } +} + +func (p *ytdlpAPIProgress) OnDone(ctx context.Context, task *ytdlp.Task, err error) { + finishAPITask(p.taskID, err) +} + +type aria2APIProgress struct{ taskID string } + +func newAria2APIProgress(taskID string) aria2dl.ProgressTracker { + return &aria2APIProgress{taskID: taskID} +} + +func (p *aria2APIProgress) OnStart(ctx context.Context, task *aria2dl.Task) { + if info, ok := GetTask(p.taskID); ok { + info.UpdateStatus(TaskStatusRunning) + info.UpdatePhase("downloading") + } +} + +func (p *aria2APIProgress) OnProgress(ctx context.Context, task *aria2dl.Task, status *aria2.Status) { + if info, ok := GetTask(p.taskID); ok { + total, _ := strconv.ParseInt(status.TotalLength, 10, 64) + completed, _ := strconv.ParseInt(status.CompletedLength, 10, 64) + info.UpdatePhase("downloading") + info.UpdateDownloadProgress(completed, total) + } +} + +func (p *aria2APIProgress) OnDone(ctx context.Context, task *aria2dl.Task, err error) { + finishAPITask(p.taskID, err) +} + +type parsedAPIProgress struct{ taskID string } + +func newParsedAPIProgress(taskID string) parsed.ProgressTracker { + return &parsedAPIProgress{taskID: taskID} +} + +func (p *parsedAPIProgress) OnStart(ctx context.Context, info parsed.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase("downloading") + task.TotalFiles = int(info.TotalResources()) + task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes()) + } +} + +func (p *parsedAPIProgress) OnProgress(ctx context.Context, info parsed.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase("downloading") + task.TotalFiles = int(info.TotalResources()) + task.DownloadedFiles = int(info.Downloaded()) + task.UpdateDownloadProgress(info.DownloadedBytes(), info.TotalBytes()) + } +} + +func (p *parsedAPIProgress) OnDone(ctx context.Context, info parsed.TaskInfo, err error) { + finishAPITask(p.taskID, err) +} + +type tfileAPIProgress struct{ taskID string } + +func newTFileAPIProgress(taskID string) tfile.ProgressTracker { + return &tfileAPIProgress{taskID: taskID} +} + +func (p *tfileAPIProgress) OnStart(ctx context.Context, info tfile.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase("downloading") + task.TotalFiles = 1 + task.UpdateDownloadProgress(0, info.FileSize()) + } +} + +func (p *tfileAPIProgress) OnProgress(ctx context.Context, info tfile.TaskInfo, downloaded, total int64) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase("downloading") + task.UpdateDownloadProgress(downloaded, total) + } +} + +func (p *tfileAPIProgress) OnDone(ctx context.Context, info tfile.TaskInfo, err error) { + finishAPITask(p.taskID, err) +} + +type batchTFileAPIProgress struct{ taskID string } + +func newBatchTFileAPIProgress(taskID string) batchtfile.ProgressTracker { + return &batchTFileAPIProgress{taskID: taskID} +} + +func (p *batchTFileAPIProgress) OnStart(ctx context.Context, info batchtfile.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase("downloading") + task.TotalFiles = info.Count() + task.UpdateDownloadProgress(0, info.TotalSize()) + } +} + +func (p *batchTFileAPIProgress) OnProgress(ctx context.Context, info batchtfile.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase("downloading") + task.TotalFiles = info.Count() + task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize()) + } +} + +func (p *batchTFileAPIProgress) OnDone(ctx context.Context, info batchtfile.TaskInfo, err error) { + finishAPITask(p.taskID, err) +} + +type telegraphAPIProgress struct{ taskID string } + +func newTelegraphAPIProgress(taskID string) tphtask.ProgressTracker { + return &telegraphAPIProgress{taskID: taskID} +} + +func (p *telegraphAPIProgress) OnStart(ctx context.Context, info tphtask.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase("downloading") + task.TotalFiles = info.TotalPics() + } +} + +func (p *telegraphAPIProgress) OnProgress(ctx context.Context, info tphtask.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase("downloading") + task.DownloadedFiles = int(info.Downloaded()) + } +} + +func (p *telegraphAPIProgress) OnDone(ctx context.Context, info tphtask.TaskInfo, err error) { + finishAPITask(p.taskID, err) +} + +type transferAPIProgress struct{ taskID string } + +func newTransferAPIProgress(taskID string) transfer.ProgressTracker { + return &transferAPIProgress{taskID: taskID} +} + +func (p *transferAPIProgress) OnStart(ctx context.Context, info transfer.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateStatus(TaskStatusRunning) + task.UpdatePhase(info.Phase()) + task.TotalFiles = info.Count() + task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize()) + task.UpdateUploadProgress(info.Uploaded(), info.TotalSize()) + task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath()) + } +} + +func (p *transferAPIProgress) OnProgress(ctx context.Context, info transfer.TaskInfo) { + if task, ok := GetTask(p.taskID); ok { + task.UpdatePhase(info.Phase()) + task.TotalFiles = info.Count() + task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize()) + task.UpdateUploadProgress(info.Uploaded(), info.TotalSize()) + task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath()) + } +} + +func (p *transferAPIProgress) OnDone(ctx context.Context, info transfer.TaskInfo, err error) { + if task, ok := GetTask(p.taskID); ok { + task.UpdateDownloadProgress(info.Downloaded(), info.TotalSize()) + task.UpdateUploadProgress(info.Uploaded(), info.TotalSize()) + task.SetTransferMeta(info.SourceStorageName(), info.SourcePath(), info.TargetStorageName(), info.TargetPath()) + } + finishAPITask(p.taskID, err) +} + +func finishAPITask(taskID string, err error) { + info, ok := GetTask(taskID) + if !ok { + return + } + if err != nil { + if errors.Is(err, context.Canceled) { + if info.Status != TaskStatusPaused { + info.UpdateStatus(TaskStatusCancelled) + info.UpdatePhase("cancelled") + } + return + } + info.SetError(err.Error()) + return + } + info.UpdateStatus(TaskStatusCompleted) + info.UpdatePhase("completed") +} diff --git a/api/types.go b/api/types.go index 5462f6c6..cea532b6 100644 --- a/api/types.go +++ b/api/types.go @@ -14,6 +14,7 @@ type TaskStatus string const ( TaskStatusQueued TaskStatus = "queued" TaskStatusRunning TaskStatus = "running" + TaskStatusPaused TaskStatus = "paused" TaskStatusCompleted TaskStatus = "completed" TaskStatusFailed TaskStatus = "failed" TaskStatusCancelled TaskStatus = "cancelled" @@ -40,22 +41,28 @@ type CreateTaskResponse struct { type TaskProgress struct { TotalBytes int64 `json:"total_bytes,omitempty"` DownloadedBytes int64 `json:"downloaded_bytes,omitempty"` + UploadedBytes int64 `json:"uploaded_bytes,omitempty"` Percent float64 `json:"percent,omitempty"` SpeedMBPS float64 `json:"speed_mbps,omitempty"` } // TaskInfoResponse 任务信息响应 type TaskInfoResponse struct { - TaskID string `json:"task_id"` - Type tasktype.TaskType `json:"type"` - Status TaskStatus `json:"status"` - Title string `json:"title"` - Progress *TaskProgress `json:"progress,omitempty"` - Storage string `json:"storage"` - Path string `json:"path"` - Error string `json:"error,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + TaskID string `json:"task_id"` + Type tasktype.TaskType `json:"type"` + Status TaskStatus `json:"status"` + Title string `json:"title"` + Progress *TaskProgress `json:"progress,omitempty"` + Storage string `json:"storage"` + Path string `json:"path"` + SourceStorage string `json:"source_storage,omitempty"` + SourcePath string `json:"source_path,omitempty"` + TargetStorage string `json:"target_storage,omitempty"` + TargetPath string `json:"target_path,omitempty"` + Phase string `json:"phase,omitempty"` + Error string `json:"error,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // TasksListResponse 任务列表响应 @@ -150,6 +157,22 @@ type TransferParams struct { TargetPath string `json:"target_path"` } +type UpdateTaskPathRequest struct { + Path string `json:"path"` +} + +type ProxyTestRequest struct { + URL string `json:"url"` + Target string `json:"target,omitempty"` +} + +type ProxyTestResponse struct { + OK bool `json:"ok"` + MS int64 `json:"ms"` + Message string `json:"message,omitempty"` + Target string `json:"target"` +} + // TGFilesParams tgfiles 任务参数 type TGFilesParams struct { MessageLinks []string `json:"message_links"` diff --git a/cmd/run.go b/cmd/run.go index b1390a2d..00e96a2b 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "io" "os" "path/filepath" "strings" @@ -15,6 +16,7 @@ import ( userclient "github.com/krau/SaveAny-Bot/client/user" "github.com/krau/SaveAny-Bot/common/cache" "github.com/krau/SaveAny-Bot/common/i18n" + "github.com/krau/SaveAny-Bot/common/logbuffer" "github.com/krau/SaveAny-Bot/common/utils/fsutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" @@ -26,7 +28,7 @@ import ( func Run(cmd *cobra.Command, _ []string) { ctx, cancel := context.WithCancel(cmd.Context()) - logger := log.NewWithOptions(os.Stdout, log.Options{ + logger := log.NewWithOptions(io.MultiWriter(os.Stdout, logbuffer.Default()), log.Options{ Level: log.InfoLevel, ReportTimestamp: true, TimeFormat: time.TimeOnly, @@ -51,10 +53,12 @@ func Run(cmd *cobra.Command, _ []string) { if err != nil { logger.Fatal("Init failed", "error", err) } - go func() { - <-exitChan - cancel() - }() + if exitChan != nil { + go func() { + <-exitChan + cancel() + }() + } core.Run(ctx) @@ -89,6 +93,10 @@ func initAll(ctx context.Context) (<-chan struct{}, error) { if err := api.Start(ctx); err != nil { logger.Error("Failed to start API server", "error", err) } + if strings.TrimSpace(config.C().Telegram.Token) == "" { + logger.Warn("Telegram bot token is empty; skip bot initialization and keep the config web/API server running") + return nil, nil + } return bot.Init(ctx), nil } diff --git a/cmd/web.go b/cmd/web.go new file mode 100644 index 00000000..54e5454d --- /dev/null +++ b/cmd/web.go @@ -0,0 +1,67 @@ +package cmd + +import ( + "fmt" + "io" + "os" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/api" + "github.com/krau/SaveAny-Bot/common/logbuffer" + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/database" + "github.com/spf13/cobra" +) + +var webCmd = &cobra.Command{ + Use: "web", + Short: "Start the visual configuration web UI", + Run: runWeb, +} + +func init() { + webCmd.Flags().StringP("config", "c", config.DefaultConfigFile, "config file path") + webCmd.Flags().String("host", "0.0.0.0", "web UI listen host") + webCmd.Flags().Int("port", config.DefaultAPIPort, "web UI listen port") + webCmd.Flags().String("token", "", "web UI API token") + rootCmd.AddCommand(webCmd) +} + +func runWeb(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + logger := log.NewWithOptions(io.MultiWriter(os.Stdout, logbuffer.Default()), log.Options{ + Level: log.InfoLevel, + ReportTimestamp: true, + TimeFormat: time.TimeOnly, + ReportCaller: true, + }) + log.SetDefault(logger) + ctx = log.WithContext(ctx, logger) + + configPath, _ := cmd.Flags().GetString("config") + host, _ := cmd.Flags().GetString("host") + port, _ := cmd.Flags().GetInt("port") + token, _ := cmd.Flags().GetString("token") + + if _, err := os.Stat(configPath); err == nil { + if err := config.Init(ctx, configPath); err != nil { + logger.Warn("Config file exists but could not be loaded; web editor will still start", "error", err) + } else if err := database.Open(ctx); err != nil { + logger.Warn("Database could not be opened; rule editor will become available after a valid config is saved", "error", err) + } + } + if token == "" && host != "127.0.0.1" && host != "localhost" { + logger.Warn("Config web UI is listening without a token", "host", host) + } + if _, err := api.StartConfigWebServer(ctx, api.ConfigWebServerOptions{ + ConfigPath: configPath, + Host: host, + Port: port, + Token: token, + }); err != nil { + logger.Fatal("Failed to start config web server", "error", err) + } + fmt.Printf("Config web UI: http://%s:%d/config\n", host, port) + <-ctx.Done() +} diff --git a/common/logbuffer/logbuffer.go b/common/logbuffer/logbuffer.go new file mode 100644 index 00000000..1dd38435 --- /dev/null +++ b/common/logbuffer/logbuffer.go @@ -0,0 +1,79 @@ +package logbuffer + +import ( + "bytes" + "regexp" + "strings" + "sync" +) + +const defaultLimit = 600 + +var ansiPattern = regexp.MustCompile(`\x1b\[[0-9;]*[A-Za-z]`) + +type Buffer struct { + mu sync.RWMutex + limit int + lines []string + partial string +} + +var defaultBuffer = New(defaultLimit) + +func Default() *Buffer { + return defaultBuffer +} + +func New(limit int) *Buffer { + if limit < 1 { + limit = defaultLimit + } + return &Buffer{limit: limit} +} + +func (b *Buffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + text := b.partial + string(p) + parts := strings.Split(text, "\n") + b.partial = parts[len(parts)-1] + for _, line := range parts[:len(parts)-1] { + b.appendLine(line) + } + return len(p), nil +} + +func (b *Buffer) Lines(limit int) []string { + b.mu.RLock() + defer b.mu.RUnlock() + + lines := b.lines + if b.partial != "" { + lines = append(append([]string{}, b.lines...), cleanLine(b.partial)) + } + if limit < 1 || limit > len(lines) { + limit = len(lines) + } + out := make([]string, limit) + copy(out, lines[len(lines)-limit:]) + return out +} + +func (b *Buffer) appendLine(line string) { + line = cleanLine(line) + if strings.TrimSpace(line) == "" { + return + } + b.lines = append(b.lines, line) + if len(b.lines) > b.limit { + copy(b.lines, b.lines[len(b.lines)-b.limit:]) + b.lines = b.lines[:b.limit] + } +} + +func cleanLine(line string) string { + line = ansiPattern.ReplaceAllString(line, "") + line = strings.TrimRight(line, "\r") + return string(bytes.TrimRight([]byte(line), "\x00")) +} diff --git a/config/edit.go b/config/edit.go new file mode 100644 index 00000000..2af45049 --- /dev/null +++ b/config/edit.go @@ -0,0 +1,743 @@ +package config + +import ( + "fmt" + "net/url" + "os" + "path/filepath" + "regexp" + "slices" + "strconv" + "strings" + + storenum "github.com/krau/SaveAny-Bot/pkg/enums/storage" + "github.com/pelletier/go-toml/v2" + "github.com/spf13/viper" +) + +const ( + DefaultConfigFile = "config.toml" + DefaultAPIPort = 19191 +) + +type EditableConfig struct { + Lang string `toml:"lang" json:"lang"` + Workers int `toml:"workers" json:"workers"` + Retry int `toml:"retry" json:"retry"` + Threads int `toml:"threads" json:"threads"` + Stream bool `toml:"stream" json:"stream"` + NoCleanCache bool `toml:"no_clean_cache" json:"no_clean_cache"` + Proxy string `toml:"proxy" json:"proxy"` + Log EditableLogConfig `toml:"log" json:"log"` + Telegram EditableTelegramConfig `toml:"telegram" json:"telegram"` + Aria2 EditableAria2Config `toml:"aria2" json:"aria2"` + API EditableAPIConfig `toml:"api" json:"api"` + Cache EditableCacheConfig `toml:"cache" json:"cache"` + Temp EditableTempConfig `toml:"temp" json:"temp"` + DB EditableDBConfig `toml:"db" json:"db"` + Parser EditableParserConfig `toml:"parser" json:"parser"` + Hook EditableHookConfig `toml:"hook" json:"hook"` + Storages []map[string]any `toml:"storages" json:"storages"` + Users []EditableUserConfig `toml:"users" json:"users"` +} + +type EditableLogConfig struct { + Level string `toml:"level" json:"level"` +} + +type EditableTelegramConfig struct { + Token string `toml:"token" json:"token"` + AppID int `toml:"app_id" json:"app_id"` + AppHash string `toml:"app_hash" json:"app_hash"` + Proxy EditableTelegramProxyConfig `toml:"proxy" json:"proxy"` + RpcRetry int `toml:"rpc_retry" json:"rpc_retry"` + Userbot EditableUserbotConfig `toml:"userbot" json:"userbot"` + MediaGroupTimeout int `toml:"media_group_timeout" json:"media_group_timeout"` +} + +type EditableTelegramProxyConfig struct { + Enable bool `toml:"enable" json:"enable"` + URL string `toml:"url" json:"url"` +} + +type EditableUserbotConfig struct { + Enable bool `toml:"enable" json:"enable"` + Session string `toml:"session" json:"session"` +} + +type EditableAria2Config struct { + Enable bool `toml:"enable" json:"enable"` + Url string `toml:"url" json:"url"` + Secret string `toml:"secret" json:"secret"` + KeepFile bool `toml:"keep_file" json:"keep_file"` +} + +type EditableAPIConfig struct { + Enable bool `toml:"enable" json:"enable"` + Host string `toml:"host" json:"host"` + Port int `toml:"port" json:"port"` + Token string `toml:"token" json:"token"` +} + +type EditableCacheConfig struct { + TTL int64 `toml:"ttl" json:"ttl"` + NumCounters int64 `toml:"num_counters" json:"num_counters"` + MaxCost int64 `toml:"max_cost" json:"max_cost"` +} + +type EditableTempConfig struct { + BasePath string `toml:"base_path" json:"base_path"` +} + +type EditableDBConfig struct { + Path string `toml:"path" json:"path"` + Session string `toml:"session" json:"session"` +} + +type EditableParserConfig struct { + PluginEnable bool `toml:"plugin_enable" json:"plugin_enable"` + PluginDirs []string `toml:"plugin_dirs" json:"plugin_dirs"` + Proxy string `toml:"proxy" json:"proxy"` + ParserCfgs map[string]map[string]any `toml:",inline" json:"parser_cfgs,omitempty"` +} + +type EditableHookConfig struct { + Exec EditableHookExecConfig `toml:"exec" json:"exec"` +} + +type EditableHookExecConfig struct { + TaskBeforeStart string `toml:"task_before_start" json:"task_before_start"` + TaskSuccess string `toml:"task_success" json:"task_success"` + TaskFail string `toml:"task_fail" json:"task_fail"` + TaskCancel string `toml:"task_cancel" json:"task_cancel"` +} + +type EditableUserConfig struct { + ID int64 `toml:"id" json:"id"` + Storages []string `toml:"storages" json:"storages"` + Blacklist bool `toml:"blacklist" json:"blacklist"` +} + +type EditableConfigFile struct { + Path string `json:"path"` + Exists bool `json:"exists"` + Config EditableConfig `json:"config"` +} + +type StorageFieldSchema struct { + Name string `json:"name"` + Label string `json:"label"` + Type string `json:"type"` + Required bool `json:"required"` + Secret bool `json:"secret,omitempty"` + Placeholder string `json:"placeholder,omitempty"` + Help string `json:"help,omitempty"` + Options []string `json:"options,omitempty"` +} + +type StorageTypeSchema struct { + Type string `json:"type"` + Label string `json:"label"` + Fields []StorageFieldSchema `json:"fields"` +} + +func ResolveConfigFilePath(path string) string { + if path != "" { + return path + } + if used := viper.ConfigFileUsed(); used != "" { + return used + } + return DefaultConfigFile +} + +func ConfigFileUsed() string { + return viper.ConfigFileUsed() +} + +func DefaultEditableConfig() EditableConfig { + return EditableConfig{ + Lang: "zh-Hans", + Workers: 4, + Retry: 3, + Threads: 4, + Log: EditableLogConfig{ + Level: "debug", + }, + Telegram: EditableTelegramConfig{ + AppID: 1025907, + AppHash: "452b0359b988148995f22ff0f4229750", + RpcRetry: 5, + Userbot: EditableUserbotConfig{ + Session: "data/usersession.db", + }, + }, + Aria2: EditableAria2Config{ + Url: "http://localhost:6800/jsonrpc", + }, + API: EditableAPIConfig{ + Host: "0.0.0.0", + Port: DefaultAPIPort, + }, + Cache: EditableCacheConfig{ + TTL: 86400, + NumCounters: 100000, + MaxCost: 1000000, + }, + Temp: EditableTempConfig{ + BasePath: "cache/", + }, + DB: EditableDBConfig{ + Path: "data/saveany.db", + Session: "data/session.db", + }, + Parser: EditableParserConfig{ + PluginDirs: []string{"plugins"}, + }, + Storages: []map[string]any{ + { + "name": "local", + "type": storenum.Local.String(), + "enable": true, + "base_path": "./downloads", + }, + }, + Users: []EditableUserConfig{}, + } +} + +func LoadEditableConfig(path string) (*EditableConfigFile, error) { + path = ResolveConfigFilePath(path) + cfg := DefaultEditableConfig() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &EditableConfigFile{Path: path, Exists: false, Config: cfg}, nil + } + return nil, fmt.Errorf("failed to read config file %s: %w", path, err) + } + if strings.TrimSpace(string(data)) == "" { + return &EditableConfigFile{Path: path, Exists: true, Config: cfg}, nil + } + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config file %s: %w", path, err) + } + NormalizeEditableConfig(&cfg) + return &EditableConfigFile{Path: path, Exists: true, Config: cfg}, nil +} + +func SaveEditableConfig(path string, cfg *EditableConfig) error { + path = ResolveConfigFilePath(path) + if isRemoteConfigPath(path) { + return fmt.Errorf("remote config files are read-only in the web editor") + } + NormalizeEditableConfig(cfg) + if err := ValidateEditableConfig(*cfg); err != nil { + return err + } + data, err := toml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to encode config: %w", err) + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil && filepath.Dir(path) != "." { + return fmt.Errorf("failed to create config directory: %w", err) + } + mode := os.FileMode(0644) + if st, err := os.Stat(path); err == nil { + mode = st.Mode().Perm() + if old, err := os.ReadFile(path); err == nil { + _ = os.WriteFile(path+".bak", old, mode) + } + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, mode); err != nil { + return fmt.Errorf("failed to write temporary config file: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("failed to replace config file: %w", err) + } + return nil +} + +func NormalizeEditableConfig(cfg *EditableConfig) { + defaults := DefaultEditableConfig() + if cfg.Lang == "" { + cfg.Lang = defaults.Lang + } + if cfg.Workers < 1 { + cfg.Workers = defaults.Workers + } + if cfg.Retry < 1 { + cfg.Retry = defaults.Retry + } + if cfg.Threads < 1 { + cfg.Threads = defaults.Threads + } + if cfg.Log.Level == "" { + cfg.Log.Level = defaults.Log.Level + } + if cfg.Telegram.AppID == 0 { + cfg.Telegram.AppID = defaults.Telegram.AppID + } + if cfg.Telegram.AppHash == "" { + cfg.Telegram.AppHash = defaults.Telegram.AppHash + } + if cfg.Telegram.RpcRetry == 0 { + cfg.Telegram.RpcRetry = defaults.Telegram.RpcRetry + } + if cfg.Telegram.Userbot.Session == "" { + cfg.Telegram.Userbot.Session = defaults.Telegram.Userbot.Session + } + if cfg.Aria2.Url == "" { + cfg.Aria2.Url = defaults.Aria2.Url + } + if cfg.API.Host == "" { + cfg.API.Host = defaults.API.Host + } + if cfg.API.Port == 0 { + cfg.API.Port = defaults.API.Port + } + if cfg.Cache.TTL == 0 { + cfg.Cache.TTL = defaults.Cache.TTL + } + if cfg.Cache.NumCounters == 0 { + cfg.Cache.NumCounters = defaults.Cache.NumCounters + } + if cfg.Cache.MaxCost == 0 { + cfg.Cache.MaxCost = defaults.Cache.MaxCost + } + if cfg.Temp.BasePath == "" { + cfg.Temp.BasePath = defaults.Temp.BasePath + } + if cfg.DB.Path == "" { + cfg.DB.Path = defaults.DB.Path + } + if cfg.DB.Session == "" { + cfg.DB.Session = defaults.DB.Session + } + if cfg.Parser.PluginDirs == nil { + cfg.Parser.PluginDirs = defaults.Parser.PluginDirs + } + if cfg.Storages == nil { + cfg.Storages = []map[string]any{} + } + if cfg.Users == nil { + cfg.Users = []EditableUserConfig{} + } + for i := range cfg.Storages { + normalizeStorageMap(cfg.Storages[i]) + fillDefaultStoragePath(cfg.Storages[i]) + } + storageNames := editableStorageNameSet(cfg.Storages) + for i := range cfg.Users { + if cfg.Users[i].Storages == nil { + cfg.Users[i].Storages = []string{} + } + cfg.Users[i].Storages = filterKnownStorageNames(cfg.Users[i].Storages, storageNames) + } +} + +func ValidateEditableConfig(cfg EditableConfig) error { + if cfg.Workers < 1 { + return fmt.Errorf("workers must be greater than 0") + } + if cfg.Retry < 1 { + return fmt.Errorf("retry must be greater than 0") + } + if cfg.Threads < 1 { + return fmt.Errorf("threads must be greater than 0") + } + if !slices.Contains([]string{"trace", "debug", "info", "warn", "error", "fatal"}, strings.ToLower(cfg.Log.Level)) { + return fmt.Errorf("invalid log level: %s", cfg.Log.Level) + } + if err := validateProxyURL("proxy", cfg.Proxy, false); err != nil { + return err + } + if err := validateProxyURL("telegram.proxy.url", cfg.Telegram.Proxy.URL, cfg.Telegram.Proxy.Enable); err != nil { + return err + } + if err := validateProxyURL("parser.proxy", cfg.Parser.Proxy, false); err != nil { + return err + } + if cfg.API.Port < 1 || cfg.API.Port > 65535 { + return fmt.Errorf("api.port must be between 1 and 65535") + } + if cfg.Cache.TTL < 0 || cfg.Cache.NumCounters < 0 || cfg.Cache.MaxCost < 0 { + return fmt.Errorf("cache values must not be negative") + } + if cfg.DB.Path == "" { + return fmt.Errorf("db.path is required") + } + if cfg.DB.Session == "" { + return fmt.Errorf("db.session is required") + } + storageNames := make(map[string]struct{}, len(cfg.Storages)) + for i, storage := range cfg.Storages { + if err := validateEditableStorage(i, storage); err != nil { + return err + } + name := getStringValue(storage, "name") + if _, ok := storageNames[name]; ok { + return fmt.Errorf("duplicate storage name: %s", name) + } + storageNames[name] = struct{}{} + } + userIDs := make(map[int64]struct{}, len(cfg.Users)) + for _, user := range cfg.Users { + if user.ID == 0 { + return fmt.Errorf("user id is required") + } + if _, ok := userIDs[user.ID]; ok { + return fmt.Errorf("duplicate user id: %d", user.ID) + } + userIDs[user.ID] = struct{}{} + for _, storageName := range user.Storages { + if _, ok := storageNames[storageName]; !ok { + return fmt.Errorf("user %d references unknown storage %s", user.ID, storageName) + } + } + } + return nil +} + +func StorageSchemas() []StorageTypeSchema { + return []StorageTypeSchema{ + { + Type: storenum.Local.String(), + Label: "Local", + Fields: []StorageFieldSchema{ + {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "./downloads"}, + }, + }, + { + Type: storenum.Webdav.String(), + Label: "WebDAV", + Fields: []StorageFieldSchema{ + {Name: "url", Label: "URL", Type: "url", Required: true, Placeholder: "https://example.com/dav"}, + {Name: "username", Label: "用户名", Type: "string", Required: true}, + {Name: "password", Label: "密码", Type: "password", Required: true, Secret: true}, + {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "/telegram"}, + }, + }, + { + Type: storenum.Alist.String(), + Label: "AList", + Fields: []StorageFieldSchema{ + {Name: "url", Label: "URL", Type: "url", Required: true, Placeholder: "https://alist.example.com"}, + {Name: "username", Label: "用户名", Type: "string"}, + {Name: "password", Label: "密码", Type: "password", Secret: true}, + {Name: "token", Label: "Token", Type: "password", Secret: true, Help: "Token 与用户名密码二选一"}, + {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "/telegram"}, + {Name: "token_exp", Label: "Token 过期时间", Type: "int"}, + }, + }, + { + Type: storenum.Minio.String(), + Label: "MinIO", + Fields: objectStorageFields(false), + }, + { + Type: storenum.S3.String(), + Label: "S3", + Fields: objectStorageFields(true), + }, + { + Type: storenum.Telegram.String(), + Label: "Telegram", + Fields: []StorageFieldSchema{ + {Name: "chat_id", Label: "Chat ID", Type: "int", Required: true, Placeholder: "-1001234567890"}, + {Name: "force_file", Label: "强制文件模式", Type: "bool"}, + {Name: "rate_limit", Label: "速率限制", Type: "int"}, + {Name: "rate_burst", Label: "突发限制", Type: "int"}, + {Name: "skip_large", Label: "跳过超大文件", Type: "bool"}, + {Name: "split_size_mb", Label: "分卷大小 MB", Type: "int"}, + }, + }, + { + Type: storenum.Rclone.String(), + Label: "Rclone", + Fields: []StorageFieldSchema{ + {Name: "remote", Label: "Remote", Type: "string", Required: true, Placeholder: "remote:"}, + {Name: "base_path", Label: "根路径", Type: "string", Placeholder: "/telegram"}, + {Name: "config_path", Label: "配置文件路径", Type: "string"}, + {Name: "flags", Label: "额外参数", Type: "string_list", Placeholder: "--transfers=4"}, + }, + }, + } +} + +func StorageTypeNames() []string { + return storenum.StorageTypeNames() +} + +func objectStorageFields(includeS3Only bool) []StorageFieldSchema { + fields := []StorageFieldSchema{ + {Name: "endpoint", Label: "Endpoint", Type: "string", Required: true, Placeholder: "s3.amazonaws.com"}, + {Name: "access_key_id", Label: "Access Key ID", Type: "string", Required: true}, + {Name: "secret_access_key", Label: "Secret Access Key", Type: "password", Required: true, Secret: true}, + {Name: "bucket_name", Label: "Bucket", Type: "string", Required: true}, + {Name: "use_ssl", Label: "Use SSL", Type: "bool"}, + {Name: "base_path", Label: "根路径", Type: "string", Required: true, Placeholder: "telegram"}, + } + if includeS3Only { + fields = append(fields, + StorageFieldSchema{Name: "region", Label: "Region", Type: "string", Placeholder: "ap-east-1"}, + StorageFieldSchema{Name: "virtual_host", Label: "Virtual Host", Type: "bool"}, + ) + } + return fields +} + +func validateEditableStorage(index int, storage map[string]any) error { + name := getStringValue(storage, "name") + if name == "" { + return fmt.Errorf("storages[%d].name is required", index) + } + storageType := getStringValue(storage, "type") + if storageType == "" { + return fmt.Errorf("storage %s type is required", name) + } + parsedType, err := storenum.ParseStorageType(storageType) + if err != nil { + return fmt.Errorf("invalid storage type %s for %s: %w", storageType, name, err) + } + if !getBoolValue(storage, "enable") { + return nil + } + required := func(key string) error { + if strings.TrimSpace(getStringValue(storage, key)) == "" { + return fmt.Errorf("%s is required for %s storage %s", key, parsedType, name) + } + return nil + } + switch parsedType { + case storenum.Local: + return required("base_path") + case storenum.Webdav: + for _, key := range []string{"url", "username", "password", "base_path"} { + if err := required(key); err != nil { + return err + } + } + case storenum.Alist: + if err := required("url"); err != nil { + return err + } + if getStringValue(storage, "token") == "" && (getStringValue(storage, "username") == "" || getStringValue(storage, "password") == "") { + return fmt.Errorf("username and password or token is required for alist storage %s", name) + } + return required("base_path") + case storenum.Minio, storenum.S3: + for _, key := range []string{"endpoint", "access_key_id", "secret_access_key", "bucket_name", "base_path"} { + if err := required(key); err != nil { + return err + } + } + case storenum.Telegram: + if getInt64Value(storage, "chat_id") == 0 { + return fmt.Errorf("chat_id is required for telegram storage %s", name) + } + if getInt64Value(storage, "rate_limit") < 0 || getInt64Value(storage, "rate_burst") < 0 { + return fmt.Errorf("rate_limit and rate_burst must not be negative for telegram storage %s", name) + } + case storenum.Rclone: + return required("remote") + } + return nil +} + +func fillDefaultStoragePath(storage map[string]any) { + defaultPath, ok := defaultStorageBasePath(getStringValue(storage, "type")) + if !ok { + return + } + if strings.TrimSpace(getStringValue(storage, "base_path")) == "" { + storage["base_path"] = defaultPath + } +} + +func defaultStorageBasePath(storageType string) (string, bool) { + switch storageType { + case storenum.Local.String(): + return "./downloads", true + case storenum.Webdav.String(), storenum.Alist.String(), storenum.Rclone.String(): + return "/telegram", true + case storenum.Minio.String(), storenum.S3.String(): + return "telegram", true + default: + return "", false + } +} + +func editableStorageNameSet(storages []map[string]any) map[string]struct{} { + names := make(map[string]struct{}, len(storages)) + for _, storage := range storages { + name := getStringValue(storage, "name") + if name != "" { + names[name] = struct{}{} + } + } + return names +} + +func validateProxyURL(field, value string, required bool) error { + value = strings.TrimSpace(value) + if value == "" { + if required { + return fmt.Errorf("%s is required", field) + } + return nil + } + u, err := url.Parse(value) + if err != nil { + return fmt.Errorf("invalid %s: %w", field, err) + } + switch strings.ToLower(u.Scheme) { + case "http", "https", "socks5", "socks5h": + return nil + default: + return fmt.Errorf("%s must use http, https, socks5, or socks5h", field) + } +} + +func isRemoteConfigPath(path string) bool { + return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") +} + +func normalizeStorageMap(storage map[string]any) { + for _, schema := range StorageSchemas() { + if getStringValue(storage, "type") != schema.Type { + continue + } + for _, field := range schema.Fields { + switch field.Type { + case "int": + if value, ok := maybeInt64(storage[field.Name]); ok { + storage[field.Name] = value + } + case "bool": + if value, ok := maybeBool(storage[field.Name]); ok { + storage[field.Name] = value + } + case "string_list": + if value, ok := maybeStringList(storage[field.Name]); ok { + storage[field.Name] = value + } + default: + if value, ok := storage[field.Name]; ok { + storage[field.Name] = strings.TrimSpace(fmt.Sprint(value)) + } + } + } + break + } + if value, ok := storage["enable"]; ok { + if b, ok := maybeBool(value); ok { + storage["enable"] = b + } + } +} + +func getStringValue(values map[string]any, key string) string { + value, ok := values[key] + if !ok || value == nil { + return "" + } + return strings.TrimSpace(fmt.Sprint(value)) +} + +func getBoolValue(values map[string]any, key string) bool { + value, ok := values[key] + if !ok { + return false + } + b, _ := maybeBool(value) + return b +} + +func getInt64Value(values map[string]any, key string) int64 { + value, ok := values[key] + if !ok { + return 0 + } + i, _ := maybeInt64(value) + return i +} + +func maybeBool(value any) (bool, bool) { + switch v := value.(type) { + case bool: + return v, true + case string: + b, err := strconv.ParseBool(strings.TrimSpace(v)) + return b, err == nil + default: + return false, false + } +} + +func maybeInt64(value any) (int64, bool) { + switch v := value.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case uint: + return int64(v), true + case uint8: + return int64(v), true + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case uint64: + if v > uint64(^uint64(0)>>1) { + return 0, false + } + return int64(v), true + case float32: + return int64(v), float32(int64(v)) == v + case float64: + return int64(v), float64(int64(v)) == v + case string: + i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64) + return i, err == nil + default: + return 0, false + } +} + +func maybeStringList(value any) ([]string, bool) { + switch v := value.(type) { + case []string: + return v, true + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + out = append(out, strings.TrimSpace(fmt.Sprint(item))) + } + return out, true + case string: + if strings.TrimSpace(v) == "" { + return []string{}, true + } + parts := regexp.MustCompile(`[\n,]+`).Split(v, -1) + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + out = append(out, part) + } + } + return out, true + default: + return nil, false + } +} diff --git a/config/viper.go b/config/viper.go index af0de879..9aa74313 100644 --- a/config/viper.go +++ b/config/viper.go @@ -129,7 +129,7 @@ func Init(ctx context.Context, configFile ...string) error { // API "api.enable": false, "api.host": "0.0.0.0", - "api.port": 8080, + "api.port": DefaultAPIPort, "api.token": "", } diff --git a/core/tasks/transfer/execute.go b/core/tasks/transfer/execute.go index dc57da67..e78befd5 100644 --- a/core/tasks/transfer/execute.go +++ b/core/tasks/transfer/execute.go @@ -2,6 +2,7 @@ package transfer import ( "context" + "errors" "fmt" "io" "os" @@ -9,17 +10,23 @@ import ( "path/filepath" "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/utils/ioutil" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/storage" "golang.org/x/sync/errgroup" ) +var errTargetPathChanged = errors.New("target path changed") + // Execute implements core.Executable. func (t *Task) Execute(ctx context.Context) error { logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("transfer[%s]", t.ID)) logger.Info("Starting transfer task") - t.Progress.OnStart(ctx, t) + t.setPhase("running") + if t.Progress != nil { + t.Progress.OnStart(ctx, t) + } workers := config.C().Workers eg, gctx := errgroup.WithContext(ctx) @@ -65,12 +72,15 @@ func (t *Task) Execute(ctx context.Context) error { logger.Info("Transfer task completed successfully") } - t.Progress.OnDone(ctx, t, err) + if t.Progress != nil { + t.Progress.OnDone(ctx, t, err) + } return err } func (t *Task) processElement(ctx context.Context, elem TaskElement) error { logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", elem.FileInfo.Name)) + uploadedSize := elem.FileInfo.Size // Check whether the source storage supports reading readableStorage, ok := elem.SourceStorage.(storage.StorageReadable) @@ -78,26 +88,38 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return fmt.Errorf("source storage %s does not support reading", elem.SourceStorage.Name()) } - logger.Info("Opening file from source storage") - reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) - } - defer reader.Close() - - // Build target storage path: /target_path/filename - storagePath := path.Join(elem.TargetPath, elem.FileInfo.Name) - - // Inject file size into context - ctx = context.WithValue(ctx, ctxkey.ContentLength, size) - if config.C().Stream { - if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil { - return fmt.Errorf("failed to upload file to storage: %w", err) + for { + logger.Info("Opening file from source storage") + reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + uploadedSize = size + err = t.uploadReader(ctx, elem, reader, size) + closeErr := reader.Close() + if err == nil && closeErr != nil { + err = closeErr + } + if errors.Is(err, errTargetPathChanged) { + continue + } + if err != nil { + return fmt.Errorf("failed to upload file to storage: %w", err) + } + break } } else { + logger.Info("Opening file from source storage") + reader, size, err := readableStorage.OpenFile(ctx, elem.SourcePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + uploadedSize = size + defer reader.Close() + logger.Info("Downloading to temporary file for ReadSeeker support") - tempFile, err := t.downloadToTemp(reader, elem.FileInfo.Name) + tempFile, err := t.downloadToTemp(ctx, reader, elem.FileInfo.Name) if err != nil { return fmt.Errorf("failed to download to temp: %w", err) } @@ -109,19 +131,30 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { } logger.Infof("Uploading file to storage (size: %d bytes)", size) - if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil { - return fmt.Errorf("failed to upload file to storage: %w", err) + for { + if _, err := tempFile.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek temp file: %w", err) + } + if err := t.uploadReader(ctx, elem, tempFile, size); err != nil { + if errors.Is(err, errTargetPathChanged) { + continue + } + return fmt.Errorf("failed to upload file to storage: %w", err) + } + break } } - t.uploaded.Add(size) - t.Progress.OnProgress(ctx, t) + t.uploaded.Add(uploadedSize) + if t.Progress != nil { + t.Progress.OnProgress(ctx, t) + } logger.Info("File uploaded successfully") return nil } -func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, error) { +func (t *Task) downloadToTemp(ctx context.Context, reader io.Reader, filename string) (*os.File, error) { tempDir := config.C().Temp.BasePath if tempDir == "" { tempDir = os.TempDir() @@ -132,7 +165,18 @@ func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, erro return nil, fmt.Errorf("failed to create temp file: %w", err) } - if _, err := io.Copy(tempFile, reader); err != nil { + t.setPhase("downloading") + if t.Progress != nil { + t.Progress.OnProgress(ctx, t) + } + wr := ioutil.NewProgressWriter(tempFile, func(n int) { + t.downloaded.Add(int64(n)) + if t.Progress != nil { + t.Progress.OnProgress(ctx, t) + } + }) + + if _, err := io.Copy(wr, reader); err != nil { tempFile.Close() os.Remove(tempFile.Name()) return nil, fmt.Errorf("failed to copy to temp file: %w", err) @@ -140,3 +184,51 @@ func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, erro return tempFile, nil } + +func (t *Task) uploadReader(ctx context.Context, elem TaskElement, reader io.Reader, size int64) error { + version := t.pathVersion.Load() + uploadCtx, cancel := context.WithCancel(ctx) + t.registerUploadCancel(elem.ID, cancel) + defer func() { + cancel() + t.clearUploadCancel(elem.ID) + }() + + t.setPhase("uploading") + if t.Progress != nil { + t.Progress.OnProgress(uploadCtx, t) + } + storagePath := path.Join(t.currentTargetPath(), elem.FileInfo.Name) + uploadCtx = context.WithValue(uploadCtx, ctxkey.ContentLength, size) + progressReader := ioutil.NewProgressReader(asReadSeeker(reader), size, func(read int64, total int64) { + t.setUploadingBytes(elem.ID, read) + if t.Progress != nil { + t.Progress.OnProgress(uploadCtx, t) + } + }) + if err := elem.TargetStorage.Save(uploadCtx, progressReader, storagePath); err != nil { + if uploadCtx.Err() != nil && ctx.Err() == nil && version != t.pathVersion.Load() { + return errTargetPathChanged + } + return err + } + return nil +} + +func asReadSeeker(reader io.Reader) io.ReadSeeker { + if rs, ok := reader.(io.ReadSeeker); ok { + return rs + } + return readSeekerAdapter{Reader: reader} +} + +type readSeekerAdapter struct { + io.Reader +} + +func (r readSeekerAdapter) Seek(offset int64, whence int) (int64, error) { + if offset == 0 && whence == io.SeekStart { + return 0, nil + } + return 0, fmt.Errorf("reader is not seekable") +} diff --git a/core/tasks/transfer/task.go b/core/tasks/transfer/task.go index 85e2bdef..aefd6725 100644 --- a/core/tasks/transfer/task.go +++ b/core/tasks/transfer/task.go @@ -30,11 +30,20 @@ type Task struct { elems []TaskElement Progress ProgressTracker IgnoreErrors bool + downloaded atomic.Int64 uploaded atomic.Int64 totalSize int64 processing map[string]TaskElementInfo processingMu sync.RWMutex failed map[string]error + targetPathMu sync.RWMutex + targetPath string + pathVersion atomic.Int64 + phaseMu sync.RWMutex + phase string + uploadMu sync.RWMutex + uploading map[string]int64 + uploadCancel map[string]context.CancelFunc } // Title implements core.Executable. @@ -81,7 +90,14 @@ func NewTransferTask( ctx: ctx, elems: elems, Progress: progress, + phase: "queued", uploaded: atomic.Int64{}, + targetPath: func() string { + if len(elems) == 0 { + return "" + } + return elems[0].TargetPath + }(), totalSize: func() int64 { var total int64 for _, elem := range elems { @@ -92,6 +108,64 @@ func NewTransferTask( processing: make(map[string]TaskElementInfo), IgnoreErrors: ignoreErrors, failed: make(map[string]error), + uploading: make(map[string]int64), + uploadCancel: make(map[string]context.CancelFunc), } return task } + +func (t *Task) UpdateTargetPath(targetPath string) { + t.targetPathMu.Lock() + t.targetPath = targetPath + t.targetPathMu.Unlock() + t.pathVersion.Add(1) + t.uploaded.Store(0) + + t.uploadMu.Lock() + t.uploading = make(map[string]int64) + for _, cancel := range t.uploadCancel { + cancel() + } + t.uploadMu.Unlock() +} + +func (t *Task) currentTargetPath() string { + t.targetPathMu.RLock() + defer t.targetPathMu.RUnlock() + return t.targetPath +} + +func (t *Task) setPhase(phase string) { + t.phaseMu.Lock() + t.phase = phase + t.phaseMu.Unlock() +} + +func (t *Task) registerUploadCancel(id string, cancel context.CancelFunc) { + t.uploadMu.Lock() + defer t.uploadMu.Unlock() + t.uploadCancel[id] = cancel +} + +func (t *Task) clearUploadCancel(id string) { + t.uploadMu.Lock() + defer t.uploadMu.Unlock() + delete(t.uploadCancel, id) + delete(t.uploading, id) +} + +func (t *Task) setUploadingBytes(id string, n int64) { + t.uploadMu.Lock() + defer t.uploadMu.Unlock() + t.uploading[id] = n +} + +func (t *Task) uploadedInFlight() int64 { + t.uploadMu.RLock() + defer t.uploadMu.RUnlock() + var total int64 + for _, n := range t.uploading { + total += n + } + return total +} diff --git a/core/tasks/transfer/taskinfo.go b/core/tasks/transfer/taskinfo.go index f9ec571e..e4d4e80a 100644 --- a/core/tasks/transfer/taskinfo.go +++ b/core/tasks/transfer/taskinfo.go @@ -26,18 +26,28 @@ func (e *TaskElement) SourceStorageName() string { type TaskInfo interface { TaskID() string TotalSize() int64 + Downloaded() int64 Uploaded() int64 Count() int Processing() []TaskElementInfo FailedFiles() []string + Phase() string + SourceStorageName() string + SourcePath() string + TargetStorageName() string + TargetPath() string } func (t *Task) TotalSize() int64 { return t.totalSize } +func (t *Task) Downloaded() int64 { + return t.downloaded.Load() +} + func (t *Task) Uploaded() int64 { - return t.uploaded.Load() + return t.uploaded.Load() + t.uploadedInFlight() } func (t *Task) Count() int { @@ -71,3 +81,34 @@ func (t *Task) FailedFiles() []string { } return result } + +func (t *Task) Phase() string { + t.phaseMu.RLock() + defer t.phaseMu.RUnlock() + return t.phase +} + +func (t *Task) SourceStorageName() string { + if len(t.elems) == 0 { + return "" + } + return t.elems[0].SourceStorage.Name() +} + +func (t *Task) SourcePath() string { + if len(t.elems) == 0 { + return "" + } + return t.elems[0].SourcePath +} + +func (t *Task) TargetStorageName() string { + if len(t.elems) == 0 { + return "" + } + return t.elems[0].TargetStorage.Name() +} + +func (t *Task) TargetPath() string { + return t.currentTargetPath() +} diff --git a/database/db.go b/database/db.go index 71870eab..b4b50385 100644 --- a/database/db.go +++ b/database/db.go @@ -17,9 +17,15 @@ import ( var db *gorm.DB func Init(ctx context.Context) { + if err := Open(ctx); err != nil { + log.FromContext(ctx).Fatal("Database initialization failed", "error", err) + } +} + +func Open(ctx context.Context) error { logger := log.FromContext(ctx) if err := os.MkdirAll(filepath.Dir(config.C().DB.Path), 0755); err != nil { - logger.Fatal("Failed to create data directory: ", err) + return fmt.Errorf("failed to create data directory: %w", err) } var err error db, err = gorm.Open(GetDialect(config.C().DB.Path), &gorm.Config{ @@ -33,20 +39,28 @@ func Init(ctx context.Context) { PrepareStmt: true, }) if err != nil { - logger.Fatal("Failed to open database: ", err) + return fmt.Errorf("failed to open database: %w", err) } logger.Debug("Database connected") if err := db.AutoMigrate(&User{}, &Dir{}, &Rule{}, &WatchChat{}); err != nil { - logger.Fatal("Database migration failed; if upgrading from an old version, try deleting the database file and retrying", "error", err) + return fmt.Errorf("database migration failed; if upgrading from an old version, try deleting the database file and retrying: %w", err) } - if err := syncUsers(ctx); err != nil { - logger.Fatal("Failed to sync users:", err) + if err := SyncUsers(ctx); err != nil { + return fmt.Errorf("failed to sync users: %w", err) } logger.Debug("Database migrated") logger.Info("Database initialized") + return nil } -func syncUsers(ctx context.Context) error { +func Ready() bool { + return db != nil +} + +func SyncUsers(ctx context.Context) error { + if db == nil { + return fmt.Errorf("database is not initialized") + } logger := log.FromContext(ctx) dbUsers, err := GetAllUsers(ctx) if err != nil { diff --git a/database/storage_cleanup_test.go b/database/storage_cleanup_test.go index a1c62c29..438098f3 100644 --- a/database/storage_cleanup_test.go +++ b/database/storage_cleanup_test.go @@ -77,6 +77,65 @@ func TestClearStorageReferences(t *testing.T) { } } +func TestRenameStorageReferences(t *testing.T) { + ctx := context.Background() + openTestDB(t) + + user := User{ + ChatID: 1001, + DefaultStorage: "old", + Silent: true, + } + if err := db.WithContext(ctx).Create(&user).Error; err != nil { + t.Fatalf("failed to create user: %v", err) + } + dir := Dir{UserID: user.ID, StorageName: "old", Path: "/old"} + if err := db.WithContext(ctx).Create(&dir).Error; err != nil { + t.Fatalf("failed to create dir: %v", err) + } + user.DefaultDir = dir.ID + if err := db.WithContext(ctx).Save(&user).Error; err != nil { + t.Fatalf("failed to update user default dir: %v", err) + } + if err := db.WithContext(ctx).Create(&Rule{ + UserID: user.ID, + Type: rule.FileNameRegex.String(), + Data: ".*", + StorageName: "old", + DirPath: "/old", + }).Error; err != nil { + t.Fatalf("failed to create rule: %v", err) + } + + if err := RenameStorageReferences(ctx, "old", "new"); err != nil { + t.Fatalf("rename failed: %v", err) + } + + var gotUser User + if err := db.WithContext(ctx).First(&gotUser, user.ID).Error; err != nil { + t.Fatalf("failed to load user: %v", err) + } + if gotUser.DefaultStorage != "new" || gotUser.DefaultDir != dir.ID || !gotUser.Silent { + t.Fatalf("expected user defaults to move to new storage, got storage=%q dir=%d silent=%v", gotUser.DefaultStorage, gotUser.DefaultDir, gotUser.Silent) + } + + var gotDir Dir + if err := db.WithContext(ctx).First(&gotDir, dir.ID).Error; err != nil { + t.Fatalf("failed to load dir: %v", err) + } + if gotDir.StorageName != "new" { + t.Fatalf("expected dir storage to be renamed, got %q", gotDir.StorageName) + } + + var gotRule Rule + if err := db.WithContext(ctx).Where("user_id = ?", user.ID).First(&gotRule).Error; err != nil { + t.Fatalf("failed to load rule: %v", err) + } + if gotRule.StorageName != "new" { + t.Fatalf("expected rule storage to be renamed, got %q", gotRule.StorageName) + } +} + func openTestDB(t *testing.T) { t.Helper() var err error diff --git a/database/user.go b/database/user.go index 3b33366d..75ee70fe 100644 --- a/database/user.go +++ b/database/user.go @@ -89,3 +89,30 @@ func ClearStorageReferences(ctx context.Context, storageName string) error { return nil }) } + +func RenameStorageReferences(ctx context.Context, oldName, newName string) error { + if db == nil { + return fmt.Errorf("database is not initialized") + } + if oldName == "" || newName == "" || oldName == newName { + return nil + } + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&User{}). + Where("default_storage = ?", oldName). + Update("default_storage", newName).Error; err != nil { + return err + } + if err := tx.Model(&Dir{}). + Where("storage_name = ?", oldName). + Update("storage_name", newName).Error; err != nil { + return err + } + if err := tx.Model(&Rule{}). + Where("storage_name = ?", oldName). + Update("storage_name", newName).Error; err != nil { + return err + } + return nil + }) +} diff --git a/storage/load.go b/storage/load.go index 09bf4bf4..e5005ff7 100644 --- a/storage/load.go +++ b/storage/load.go @@ -2,7 +2,9 @@ package storage import ( "context" + "errors" "fmt" + "sync" "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" @@ -10,6 +12,7 @@ import ( ) var UserStorages = make(map[int64][]Storage) +var storagesMu sync.RWMutex // GetStorageByName returns storage by name from cache or creates new one // It should NOT be used to get storage for user, use GetStorageByUserIDAndName instead @@ -18,10 +21,13 @@ func GetStorageByName(ctx context.Context, name string) (Storage, error) { return nil, ErrStorageNameEmpty } + storagesMu.RLock() storage, ok := Storages[name] + storagesMu.RUnlock() if ok { return storage, nil } + cfg := config.C().GetStorageByName(name) if cfg == nil { return nil, fmt.Errorf("未找到存储 %s", name) @@ -31,7 +37,14 @@ func GetStorageByName(ctx context.Context, name string) (Storage, error) { if err != nil { return nil, err } + + storagesMu.Lock() + if existing, ok := Storages[name]; ok { + storagesMu.Unlock() + return existing, nil + } Storages[name] = storage + storagesMu.Unlock() return storage, nil } @@ -52,9 +65,14 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage { if chatID <= 0 { return nil } + storagesMu.RLock() if storages, ok := UserStorages[chatID]; ok { - return storages + out := append([]Storage(nil), storages...) + storagesMu.RUnlock() + return out } + storagesMu.RUnlock() + var storages []Storage for _, name := range config.C().GetStorageNamesByUserID(chatID) { storage, err := GetStorageByName(ctx, name) @@ -63,22 +81,64 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage { } storages = append(storages, storage) } - return storages + storagesMu.Lock() + UserStorages[chatID] = storages + out := append([]Storage(nil), storages...) + storagesMu.Unlock() + return out } func LoadStorages(ctx context.Context) { + logger := log.FromContext(ctx) + if err := ReloadStorages(ctx); err != nil { + logger.Errorf("failed to load some storages: %v", err) + } +} + +func ReloadStorages(ctx context.Context) error { logger := log.FromContext(ctx) logger.Debug("loading storages...") - for _, storage := range config.C().Storages { - _, err := GetStorageByName(ctx, storage.GetName()) + + nextStorages := make(map[string]Storage) + var errs []error + for _, cfg := range config.C().Storages { + storage, err := NewStorage(ctx, cfg) if err != nil { - logger.Errorf("failed to load storage %s: %v", storage.GetName(), err) + errs = append(errs, fmt.Errorf("%s: %w", cfg.GetName(), err)) + logger.Errorf("failed to load storage %s: %v", cfg.GetName(), err) + continue } + nextStorages[cfg.GetName()] = storage } - logger.Infof("successfully loaded %d storages", len(Storages)) - for user := range config.C().GetUsersID() { - UserStorages[int64(user)] = GetUserStorages(ctx, int64(user)) + + nextUserStorages := make(map[int64][]Storage) + for _, userID := range config.C().GetUsersID() { + for _, name := range config.C().GetStorageNamesByUserID(userID) { + storage, ok := nextStorages[name] + if ok { + nextUserStorages[userID] = append(nextUserStorages[userID], storage) + } + } + } + + storagesMu.Lock() + Storages = nextStorages + UserStorages = nextUserStorages + storagesMu.Unlock() + + logger.Infof("successfully loaded %d storages", len(nextStorages)) + return errors.Join(errs...) +} + +func Snapshot() map[string]Storage { + storagesMu.RLock() + defer storagesMu.RUnlock() + + out := make(map[string]Storage, len(Storages)) + for name, storage := range Storages { + out[name] = storage } + return out } // GetTelegramStorageByUserID returns the first enabled Telegram storage for the user From 6ef6257bce2a45ee7000a39be6546e6bf44e0dc0 Mon Sep 17 00:00:00 2001 From: liyw0205 Date: Mon, 25 May 2026 12:40:37 +0800 Subject: [PATCH 3/3] ci(release): build android arm64 --- .github/workflows/build-release.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index 3d977913..9cc5cd86 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -38,6 +38,9 @@ jobs: matrix: goos: [linux, darwin, windows] goarch: [amd64, arm64] + include: + - goos: android + goarch: arm64 exclude: - goos: windows goarch: arm64