diff --git a/drivers/all.go b/drivers/all.go index fb68d0395..91b86d618 100644 --- a/drivers/all.go +++ b/drivers/all.go @@ -23,6 +23,7 @@ import ( _ "github.com/OpenListTeam/OpenList/v4/drivers/baidu_photo" _ "github.com/OpenListTeam/OpenList/v4/drivers/chaoxing" _ "github.com/OpenListTeam/OpenList/v4/drivers/chunk" + _ "github.com/OpenListTeam/OpenList/v4/drivers/cloudflare_imgbed" _ "github.com/OpenListTeam/OpenList/v4/drivers/cloudreve" _ "github.com/OpenListTeam/OpenList/v4/drivers/cloudreve_v4" _ "github.com/OpenListTeam/OpenList/v4/drivers/cnb_releases" diff --git a/drivers/cloudflare_imgbed/driver.go b/drivers/cloudflare_imgbed/driver.go new file mode 100644 index 000000000..52cef4391 --- /dev/null +++ b/drivers/cloudflare_imgbed/driver.go @@ -0,0 +1,128 @@ +package cloudflare_imgbed + +import ( + "context" + "fmt" + "net/http" + "path" + "strings" + + "github.com/OpenListTeam/OpenList/v4/drivers/base" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type CFImgBed struct { + model.Storage + Addition + client *resty.Client +} + +func (d *CFImgBed) Config() driver.Config { return config } +func (d *CFImgBed) GetAddition() driver.Additional { return &d.Addition } + +func (d *CFImgBed) Init(ctx context.Context) error { + d.UploadThread = min(d.UploadThread, 32) + if d.UploadThread < 1 { + d.UploadThread = 3 + } + d.Address = strings.TrimRight(d.Address, "/") + + d.client = base.NewRestyClient(). + SetBaseURL(d.Address). + SetHeader("Authorization", "Bearer "+d.Token). + SetDebug(false) + + // 连通性测试:尝试获取根目录单条数据 + _, err := d.doRequest(http.MethodGet, listApi, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "start": "0", + "count": "1", + "dir": "/", + }) + }, nil) + if err != nil { + return fmt.Errorf("init verification failed: %w", err) + } + return nil +} + +func (d *CFImgBed) Drop(ctx context.Context) error { return nil } + +func (d *CFImgBed) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + reqPath := dir.GetPath() + + dirSeen := make(map[string]bool) + fileSeen := make(map[string]bool) + objs := make([]model.Obj, 0) + + start := 0 + for { + var resp ListResponse + _, err := d.doRequest(http.MethodGet, listApi, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "dir": reqPath, + "start": fmt.Sprintf("%d", start), + "count": fmt.Sprintf("%d", listPageSize), + }) + }, &resp) + if err != nil { + return nil, err + } + + for _, rawDir := range resp.Directories { + cleanDir := strings.TrimRight(rawDir, "/") + if !dirSeen[cleanDir] { + dirSeen[cleanDir] = true + objs = append(objs, &model.Object{ + Path: cleanDir, + Name: path.Base(cleanDir), + Modified: d.Modified, + IsFolder: true, + }) + } + } + + for _, item := range resp.Files { + if !fileSeen[item.Name] { + fileSeen[item.Name] = true + objs = append(objs, parseFile(item)) + } + } + + // 如果当前获取的数量少于分页大小,说明已加载完毕 + if len(resp.Files)+len(resp.Directories) < listPageSize { + break + } + start += listPageSize + } + return objs, nil +} + +func (d *CFImgBed) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + return &model.Link{URL: d.Address + "/file/" + utils.EncodePath(file.GetPath())}, nil +} + +// MakeDir 在图床中通常是虚拟的,此处返回虚拟目录对象以支持上传时的路径展示 +func (d *CFImgBed) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + fullPath := path.Join(parentDir.GetPath(), dirName) + return &model.Object{ + Path: fullPath, + Name: dirName, + IsFolder: true, + }, nil +} + +func (d *CFImgBed) Remove(ctx context.Context, obj model.Obj) error { + reqPath := obj.GetPath() + _, err := d.doRequest(http.MethodPost, deleteApi, func(req *resty.Request) { + req.SetBody(map[string]string{ + "path": reqPath, + }).SetQueryParam("folder", fmt.Sprintf("%t", obj.IsDir())) + }, nil) + return err +} + +var _ driver.Driver = (*CFImgBed)(nil) diff --git a/drivers/cloudflare_imgbed/meta.go b/drivers/cloudflare_imgbed/meta.go new file mode 100644 index 000000000..46285ca13 --- /dev/null +++ b/drivers/cloudflare_imgbed/meta.go @@ -0,0 +1,27 @@ +package cloudflare_imgbed + +import ( + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/op" +) + +type Addition struct { + driver.RootPath + Address string `json:"address" type:"text" required:"true" help:"Backend API address of the image hosting service, e.g., https://img.example.com"` + Token string `json:"token" type:"text" required:"true" help:"Authentication Token"` + SmallChannelName string `json:"smallChannelName" type:"text" help:"Channel name for regular files (typically <20MB)"` + LargeChannelName string `json:"largeChannelName" type:"text" help:"Channel name for large files"` + LargeChannelType string `json:"largeChannelType" type:"select" options:",huggingface" help:"Special type for large file channels (select 'huggingface' for direct upload to HuggingFace)"` + UploadThread int `json:"uploadThread" type:"number" default:"3" help:"Concurrent thread count for HuggingFace chunked direct upload"` +} + +var config = driver.Config{ + Name: "cloudflare_imgbed", + LocalSort: true, + NoUpload: false, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { return &CFImgBed{} }) +} diff --git a/drivers/cloudflare_imgbed/types.go b/drivers/cloudflare_imgbed/types.go new file mode 100644 index 000000000..d65a1f6fa --- /dev/null +++ b/drivers/cloudflare_imgbed/types.go @@ -0,0 +1,114 @@ +package cloudflare_imgbed + +import ( + "fmt" + "path" + "strconv" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +const listPageSize = 1000 + +// ListResponse 列表接口响应 +type ListResponse struct { + Files []FileItem `json:"files"` + Directories []string `json:"directories"` +} + +type FileItem struct { + Name string `json:"name"` + Metadata map[string]interface{} `json:"metadata"` // 存储文件大小、哈希、时间戳等 +} + +type apiError struct { + Error string `json:"error"` + Message string `json:"message"` +} + +// standardUploadResp 标准上传成功返回的数组 +type standardUploadResp []struct { + Src string `json:"src"` +} + +// hfGetUrlResp 获取 HF 直传授权地址的响应 +type hfGetUrlResp struct { + Success bool `json:"success"` + FullID string `json:"fullId"` + FilePath string `json:"filePath"` + ChannelName string `json:"channelName"` + Repo string `json:"repo"` + NeedsLfs bool `json:"needsLfs"` // 是否需要进行 LFS 物理上传 + AlreadyExists bool `json:"alreadyExists"` // 是否秒传成功 + Oid string `json:"oid"` // Git LFS 对象 ID (SHA256) + UploadAction *UploadAction `json:"uploadAction"` +} + +type UploadAction struct { + Href string `json:"href"` + Header map[string]string `json:"header"` +} + +type hfCommitResp struct { + Success bool `json:"success"` + Src string `json:"src"` + FileUrl string `json:"fileUrl"` + FullID string `json:"fullId"` +} + +// 辅助函数:从 map 中安全提取字符串/数值 +func getString(m map[string]interface{}, keys ...string) string { + for _, k := range keys { + if v, ok := m[k]; ok { + switch val := v.(type) { + case string: + return val + case float64: + return strconv.FormatInt(int64(val), 10) + default: + return fmt.Sprintf("%v", val) + } + } + } + return "" +} + +func getInt64(m map[string]interface{}, keys ...string) int64 { + for _, k := range keys { + if v, ok := m[k]; ok { + switch val := v.(type) { + case string: + n, _ := strconv.ParseInt(val, 10, 64) + return n + case float64: + return int64(val) + case int64: + return val + } + } + } + return 0 +} + +func parseFile(item FileItem) *model.Object { + name := path.Base(item.Name) + var size int64 + var modTime time.Time + + if item.Metadata != nil { + size = getInt64(item.Metadata, "FileSizeBytes", "File-Size") + ts := getInt64(item.Metadata, "TimeStamp") + if ts > 0 { + modTime = time.UnixMilli(ts) + } + } + + return &model.Object{ + Path: item.Name, + Name: name, + Size: size, + Modified: modTime, + IsFolder: false, + } +} diff --git a/drivers/cloudflare_imgbed/upload.go b/drivers/cloudflare_imgbed/upload.go new file mode 100644 index 000000000..b860c73f5 --- /dev/null +++ b/drivers/cloudflare_imgbed/upload.go @@ -0,0 +1,356 @@ +package cloudflare_imgbed + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/OpenListTeam/OpenList/v4/drivers/base" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +func (d *CFImgBed) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + fileSize := file.GetSize() + // 如果文件较大且配置了 HuggingFace 渠道,走直传流程 + if fileSize >= hfDirectThreshold && d.LargeChannelType == "huggingface" { + log.WithField("size", fileSize).Debug("file exceeds threshold, using HuggingFace direct upload") + return d.hfDirectUpload(ctx, dstDir, file, up) + } + // 否则走普通图床 API 上传 + return d.standardUpload(ctx, dstDir, file, up) +} + +// standardUpload 通过普通 multipart 表单上传。 +// 使用 io.MultiReader 实现虚拟拼接,避免将整个大文件读入内存构建表单。 +func (d *CFImgBed) standardUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + + channelName := d.SmallChannelName + if file.GetSize() >= hfDirectThreshold { + channelName = d.LargeChannelName + log.WithField("size", file.GetSize()).Warn("File exceeds threshold but non-HF channel is used.") + } + if channelName == "" { + return nil, fmt.Errorf("channel name not configured") + } + + // 1. 将参数放入 Query String + reqUrl, _ := url.Parse(d.Address + uploadApi) + q := reqUrl.Query() + q.Set("uploadFolder", dstDir.GetPath()) + q.Set("returnFormat", "default") + q.Set("channelName", channelName) + reqUrl.RawQuery = q.Encode() + + // 2. 构建 multipart 表单的头部 + b := bytes.NewBuffer(make([]byte, 0, 164+len(file.GetName()))) // 预估头部大小,避免频繁扩容 + w := multipart.NewWriter(b) + _, err := w.CreateFormFile("file", file.GetName()) + if err != nil { + return nil, err + } + headSize := b.Len() + err = w.Close() + if err != nil { + return nil, err + } + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + + // 3. 将 [表单头 + 文件流 + 表单尾] 组合成单一 Reader + rateLimitedReader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: &driver.SimpleReaderWithSize{ + Reader: io.MultiReader(head, file, tail), + Size: int64(b.Len()) + file.GetSize(), + }, + UpdateProgress: up, + }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl.String(), rateLimitedReader) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", w.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+d.Token) + req.ContentLength = int64(b.Len()) + file.GetSize() + res, err := base.HttpClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + b.Reset() + _, err = b.ReadFrom(res.Body) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("upload failed %d: %s", res.StatusCode, b.String()) + } + + var resp standardUploadResp + if err := json.Unmarshal(b.Bytes(), &resp); err != nil { + return nil, err + } + if len(resp) == 0 || resp[0].Src == "" { + return nil, fmt.Errorf("no src returned") + } + + srcPath := strings.TrimPrefix(resp[0].Src, "/file/") + srcPath = strings.TrimPrefix(srcPath, "/") + + return &model.Object{ + Path: srcPath, + Name: file.GetName(), + Size: file.GetSize(), + Modified: file.ModTime(), + IsFolder: false, + }, nil +} + +// hfDirectUpload 处理 HuggingFace 的 LFS 直传逻辑(申请授权 -> 物理上传 -> 后端 Commit) +func (d *CFImgBed) hfDirectUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + channelName := d.LargeChannelName + if channelName == "" { + return nil, errors.New("LargeChannelName not configured") + } + + sha256Hash := file.GetHash().GetHash(utils.SHA256) + if len(sha256Hash) != utils.SHA256.Width { + var err error + _, sha256Hash, err = stream.CacheFullAndHash(file, &up, utils.SHA256) + if err != nil { + return nil, err + } + } + + fileSize := file.GetSize() + sampleSize := min(fileSize, fileSampleSize) + sampleRd, err := file.RangeRead(http_range.Range{Start: 0, Length: sampleSize}) + if err != nil { + return nil, err + } + sampleBuf := make([]byte, sampleSize) + _, err = io.ReadFull(sampleRd, sampleBuf) + if err != nil && err != io.EOF { + return nil, err + } + fileSample := base64.StdEncoding.EncodeToString(sampleBuf) + + fileMime := file.GetMimetype() + // 1. 请求图床后端获取 HF 授权地址 + reqBody := map[string]interface{}{ + "fileName": file.GetName(), + "fileType": fileMime, + "fileSize": fileSize, + "sha256": sha256Hash, + "fileSample": fileSample, + "channelName": channelName, + "uploadFolder": dstDir.GetPath(), + } + + var getUrlResp hfGetUrlResp + _, err = d.doRequest(http.MethodPost, hfGetUrlApi, func(req *resty.Request) { + req.SetBody(reqBody) + req.SetHeader("Content-Type", "application/json") + }, &getUrlResp) + if err != nil { + return nil, err + } + + // 秒传逻辑 + if getUrlResp.AlreadyExists || !getUrlResp.NeedsLfs { + return d.hfCommit(ctx, getUrlResp, file.GetName(), fileSize, fileMime, file.ModTime()) + } + + if getUrlResp.UploadAction == nil { + return nil, fmt.Errorf("HF upload action is nil") + } + + headers := getUrlResp.UploadAction.Header + href := getUrlResp.UploadAction.Href + + // 2. 根据响应判断是执行分片上传还是单文件上传 + chunkSizeStr, needChunk := headers["chunk_size"] + if needChunk { + // 分片直传 (AWS S3 Multipart 风格) + chunkSize, _ := strconv.ParseInt(chunkSizeStr, 10, 64) + if chunkSize <= 0 { + chunkSize = 20 * 1024 * 1024 + } + + partUrls := make(map[int]string) + for k, v := range headers { + if len(k) == 5 { // 格式通常为 "00001", "00002" + if idx, err := strconv.Atoi(k); err == nil { + partUrls[idx] = v + } + } + } + totalParts := len(partUrls) + + ss, err := stream.NewStreamSectionReader(file, int(chunkSize), &up) + if err != nil { + return nil, err + } + + g, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, min(d.UploadThread, totalParts), + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + + parts := make([]map[string]any, totalParts) + + for partNumber := range partUrls { + if utils.IsCanceled(uploadCtx) { + break + } + partUrl := partUrls[partNumber] + offset := int64(partNumber-1) * chunkSize + sizeToRead := chunkSize + if offset+sizeToRead > fileSize { + sizeToRead = fileSize - offset + } + + var reader io.ReadSeeker + g.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) (err error) { + reader, err = ss.GetSectionReader(offset, sizeToRead) + return + }, + After: func(err error) { + ss.FreeSectionReader(reader) + }, + Do: func(ctx context.Context) (err error) { + _, err = reader.Seek(0, io.SeekStart) + if err != nil { + return err + } + limitedReader := driver.NewLimitedUploadStream(ctx, reader) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, partUrl, limitedReader) + if err != nil { + return err + } + for key, val := range headers { + if len(key) != 5 && key != "chunk_size" { + req.Header.Set(key, val) + } + } + req.ContentLength = sizeToRead + + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return fmt.Errorf("chunk %d failed: %d", partNumber, res.StatusCode) + } + + etag := res.Header.Get("ETag") + parts[partNumber-1] = map[string]any{"partNumber": partNumber, "etag": etag} + + up(95 * float64(g.Success()+1) / float64(totalParts)) + return nil + }, + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + + // 合并分片 + // sort.Slice(parts, func(i, j int) bool { return parts[i]["partNumber"].(int) < parts[j]["partNumber"].(int) }) + mergeBody, _ := json.Marshal(map[string]any{"oid": getUrlResp.Oid, "parts": parts}) + mergeReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, href, bytes.NewReader(mergeBody)) + mergeReq.Header.Set("Content-Type", "application/vnd.git-lfs+json") + for k, v := range headers { + if k != "chunk_size" && len(k) != 5 { + mergeReq.Header.Set(k, v) + } + } + res, err := base.HttpClient.Do(mergeReq) + if err != nil { + return nil, err + } + up(97) + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("merge chunks failed") + } + + } else { + // 单文件直传 (PUT) + limitedReader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: model.UpdateProgressWithRange(up, 0, 97), + }) + + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, href, limitedReader) + req.ContentLength = fileSize + for k, v := range headers { + req.Header.Set(k, v) + } + res, err := base.HttpClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("direct upload failed") + } + } + + defer up(100) + + // 3. 通知图床后端完成文件登记 + return d.hfCommit(ctx, getUrlResp, file.GetName(), fileSize, fileMime, file.ModTime()) +} + +func (d *CFImgBed) hfCommit(ctx context.Context, getUrlResp hfGetUrlResp, fileName string, fileSize int64, fileMime string, modTime time.Time) (model.Obj, error) { + commitBody := map[string]interface{}{ + "fullId": getUrlResp.FullID, + "filePath": getUrlResp.FilePath, + "sha256": getUrlResp.Oid, + "fileSize": fileSize, + "fileName": fileName, + "fileType": fileMime, + "channelName": getUrlResp.ChannelName, + } + var commitResp hfCommitResp + _, err := d.doRequest(http.MethodPost, hfCommitApi, func(req *resty.Request) { + req.SetBody(commitBody) + }, &commitResp) + if err != nil || !commitResp.Success { + return nil, fmt.Errorf("HF commit failed") + } + + srcPath := strings.TrimPrefix(commitResp.Src, "/file/") + srcPath = strings.TrimPrefix(srcPath, "/") + + return &model.Object{ + Path: srcPath, + Name: fileName, + Size: fileSize, + Modified: modTime, + IsFolder: false, + }, nil +} diff --git a/drivers/cloudflare_imgbed/util.go b/drivers/cloudflare_imgbed/util.go new file mode 100644 index 000000000..7a6b50502 --- /dev/null +++ b/drivers/cloudflare_imgbed/util.go @@ -0,0 +1,67 @@ +package cloudflare_imgbed + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +const ( + listApi = "/api/manage/list" + deleteApi = "/api/manage/delete" + uploadApi = "/upload" + hfGetUrlApi = "/upload/huggingface/getUploadUrl" + hfCommitApi = "/upload/huggingface/commitUpload" + hfDirectThreshold int64 = 20 * 1024 * 1024 + fileSampleSize = 512 // HF 申请上传地址时需提供文件前 512 字节的 Sample +) + +// doRequest 通用请求封装,包含重试和 API 错误解析 +func (d *CFImgBed) doRequest(method, urlPath string, callback func(*resty.Request), resp interface{}) ([]byte, error) { + maxRetries := 3 + for i := 0; i < maxRetries; i++ { + req := d.client.R() + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + + res, err := req.Execute(method, urlPath) + if err != nil { + log.WithError(err).Warnf("request %s %s failed, attempt %d/%d", method, urlPath, i+1, maxRetries) + if i < maxRetries-1 { + time.Sleep(time.Duration(i+1) * time.Second) + continue + } + return nil, err + } + + body := res.Body() + var apiErr apiError + if err := json.Unmarshal(body, &apiErr); err == nil { + if apiErr.Error != "" || apiErr.Message != "" { + msg := apiErr.Error + if msg == "" { + msg = apiErr.Message + } + return nil, fmt.Errorf("API error: %s", msg) + } + } + + if res.StatusCode() == 429 { + time.Sleep(time.Duration(i+1) * 2 * time.Second) + continue + } + + if res.IsError() { + return nil, fmt.Errorf("HTTP %d", res.StatusCode()) + } + return body, nil + } + return nil, fmt.Errorf("max retries exceeded") +}