diff --git a/.github/workflows/test_docker.yml b/.github/workflows/test_docker.yml index aa6fe8966..b46413841 100644 --- a/.github/workflows/test_docker.yml +++ b/.github/workflows/test_docker.yml @@ -1,5 +1,4 @@ name: Beta Release (Docker) - on: workflow_dispatch: push: @@ -7,51 +6,51 @@ on: - main pull_request: branches: - - main + - fix # allow the fix branch to trigger this workflow concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true env: - DOCKERHUB_ORG_NAME: ${{ vars.DOCKERHUB_ORG_NAME || 'openlistteam' }} - GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'openlistteam' }} - IMAGE_NAME: openlist-git - IMAGE_NAME_DOCKERHUB: openlist + GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'ironboxplus' }} # preferably set this to your own username so the image is not pushed to the wrong place + IMAGE_NAME: openlist REGISTRY: ghcr.io ARTIFACT_NAME: 'binaries_docker_release' - RELEASE_PLATFORMS: 'linux/amd64,linux/arm64,linux/arm/v7,linux/386,linux/arm/v6,linux/ppc64le,linux/riscv64,linux/loong64' ### Temporarily disable Docker builds for linux/s390x architectures for unknown reasons. - IMAGE_PUSH: ${{ github.event_name == 'push' }} + # key change: keep only linux/amd64 and drop the long list of other platforms + RELEASE_PLATFORMS: 'linux/amd64' + # key change: always allow pushing, regardless of whether this is a push event + IMAGE_PUSH: 'true' IMAGE_TAGS_BETA: | type=ref,event=pr - type=raw,value=beta,enable={{is_default_branch}} + type=raw,value=beta-baidu jobs: build_binary: - name: Build Binaries for Docker Release (Beta) + name: Build Binaries (x64 Only) runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - - uses: actions/setup-go@v5 with: go-version: '1.25.0' + # even when building only x64, the musl toolchain is still required (BuildDockerMultiplatform checks for it by default) - name: Cache Musl id: cache-musl uses: actions/cache@v4 with: path: build/musl-libs key: docker-musl-libs-v2 - - name: Download Musl Library if: steps.cache-musl.outputs.cache-hit != 'true' run: bash build.sh prepare docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Build go binary (beta) + - name: Build go binary + # still run docker-multiplatform here; it compiles a few extra architectures, but it is the safest way to stay compatible with the paths the Dockerfile expects run: bash build.sh beta docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -69,12 +68,13 @@ jobs: release_docker: needs: build_binary - name: Release Docker image (Beta) + name: Release Docker (x64) runs-on: ubuntu-latest permissions: packages: write strategy: matrix: + # you can choose to build only "latest", or keep all variants image: ["latest", "ffmpeg", "aria2", "aio"] include: - image: "latest" @@ -102,46 +102,32 @@ with: name: ${{ env.ARTIFACT_NAME }} path: 'build/' - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + # keep only the GitHub login; the DockerHub login has been removed - name: Login to GitHub Container Registry - if: env.IMAGE_PUSH == 'true' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Login to DockerHub Container Registry - if: env.IMAGE_PUSH == 'true' - uses: docker/login-action@v3 - with: - username: ${{ vars.DOCKERHUB_ORG_NAME_BACKUP || env.DOCKERHUB_ORG_NAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Docker meta id: meta uses: docker/metadata-action@v5 with: images: | ${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ env.IMAGE_NAME }} - ${{ env.DOCKERHUB_ORG_NAME }}/${{ env.IMAGE_NAME_DOCKERHUB }} tags: ${{ env.IMAGE_TAGS_BETA }} - flavor: | - ${{ matrix.tag_favor }} + flavor: ${{ matrix.tag_favor }} - name: Build and push - id: docker_build uses: docker/build-push-action@v6 with: context: . 
file: Dockerfile.ci - push: ${{ env.IMAGE_PUSH == 'true' }} + push: true build-args: | BASE_IMAGE_TAG=${{ matrix.base_image_tag }} ${{ matrix.build_arg }} diff --git a/build.sh b/build.sh index 26e5a301b..0e8f4b85d 100644 --- a/build.sh +++ b/build.sh @@ -186,8 +186,8 @@ BuildDockerMultiplatform() { docker_lflags="--extldflags '-static -fpic' $ldflags" export CGO_ENABLED=1 - OS_ARCHES=(linux-amd64 linux-arm64 linux-386 linux-riscv64 linux-ppc64le linux-loong64) ## Disable linux-s390x builds - CGO_ARGS=(x86_64-linux-musl-gcc aarch64-linux-musl-gcc i486-linux-musl-gcc riscv64-linux-musl-gcc powerpc64le-linux-musl-gcc loongarch64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds + OS_ARCHES=(linux-amd64) ## Disable linux-s390x builds + CGO_ARGS=(x86_64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds for i in "${!OS_ARCHES[@]}"; do os_arch=${OS_ARCHES[$i]} cgo_cc=${CGO_ARGS[$i]} @@ -205,14 +205,14 @@ BuildDockerMultiplatform() { GO_ARM=(6 7) export GOOS=linux export GOARCH=arm - for i in "${!DOCKER_ARM_ARCHES[@]}"; do - docker_arch=${DOCKER_ARM_ARCHES[$i]} - cgo_cc=${CGO_ARGS[$i]} - export GOARM=${GO_ARM[$i]} - export CC=${cgo_cc} - echo "building for $docker_arch" - go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . - done + # for i in "${!DOCKER_ARM_ARCHES[@]}"; do + # docker_arch=${DOCKER_ARM_ARCHES[$i]} + # cgo_cc=${CGO_ARGS[$i]} + # export GOARM=${GO_ARM[$i]} + # export CC=${cgo_cc} + # echo "building for $docker_arch" + # go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . + # done } BuildRelease() { diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index ac75e51d7..90112b37a 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -2,7 +2,9 @@ package _123_open import ( "context" + "encoding/hex" "fmt" + "io" "strconv" "time" @@ -10,7 +12,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" - "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) @@ -156,20 +158,60 @@ func (d *Open123) Remove(ctx context.Context, obj model.Obj) error { } func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - // 1. 创建文件 + // 1. 
Prepare the parameters // parentFileID 父目录id,上传到根目录时填写 0 parentFileId, err := strconv.ParseInt(dstDir.GetID(), 10, 64) if err != nil { return nil, fmt.Errorf("parse parentFileID error: %v", err) } + // etag 文件md5 etag := file.GetHash().GetHash(utils.MD5) - if len(etag) < utils.MD5.Width { - _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) + if len(etag) >= utils.MD5.Width { + // If an etag is already available, try rapid upload first + createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err } + // Whether rapid upload succeeded + if createResp.Data.Reuse { + // A valid FileID is returned only when rapid upload succeeds; otherwise it is 0 + if createResp.Data.FileID != 0 { + return File{ + FileName: file.GetName(), + Size: file.GetSize(), + FileId: createResp.Data.FileID, + Type: 2, + Etag: etag, + }, nil + } + } + // Rapid upload failed; the etag may be unreliable, so fall through and compute the real MD5 from the stream } + + // Compute the MD5 by streaming + md5Hash := utils.MD5.NewFunc() + size := file.GetSize() + chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk for MD5 calculation + var offset int64 = 0 + for offset < size { + readSize := min(chunkSize, size-offset) + reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) + if err != nil { + return nil, fmt.Errorf("range read for MD5 calculation failed: %w", err) + } + if _, err := io.Copy(md5Hash, reader); err != nil { + return nil, fmt.Errorf("calculate MD5 failed: %w", err) + } + offset += readSize + + progress := 40 * float64(offset) / float64(size) + up(progress) + } + + etag = hex.EncodeToString(md5Hash.Sum(nil)) + + // 2. Create the upload task createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err } @@ -188,13 +230,16 @@ } } - // 2. 上传分片 - err = d.Upload(ctx, file, createResp, up) + // 3. Upload the chunks + uploadProgress := func(p float64) { + up(40 + p*0.6) + } + err = d.Upload(ctx, file, createResp, uploadProgress) if err != nil { return nil, err } - // 3. 上传完毕 + // 4. Merge the chunks / finish the upload
for range 60 { uploadCompleteResp, err := d.complete(createResp.Data.PreuploadID) // 返回错误代码未知,如:20103,文档也没有具体说 diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index fe77aca38..7021f21c9 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -1,30 +1,18 @@ package baidu_netdisk import ( - "bytes" "context" - "crypto/md5" - "encoding/hex" "errors" - "io" - "mime/multipart" - "net/http" "net/url" - "os" stdpath "path" "strconv" - "strings" "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" - "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" - "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" - "github.com/OpenListTeam/OpenList/v4/internal/net" - "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/avast/retry-go" log "github.com/sirupsen/logrus" ) @@ -199,80 +187,26 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return newObj, nil } - var ( - cache = stream.GetFile() - tmpF *os.File - err error - ) - if cache == nil { - tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - defer func() { - _ = tmpF.Close() - _ = os.Remove(tmpF.Name()) - }() - cache = tmpF - } - streamSize := stream.GetSize() sliceSize := d.getSliceSize(streamSize) count := 1 if streamSize > sliceSize { count = int((streamSize + sliceSize - 1) / sliceSize) } - lastBlockSize := streamSize % sliceSize - if lastBlockSize == 0 { - lastBlockSize = sliceSize - } - - // cal md5 for first 256k data - const SliceSize int64 = 256 * utils.KB - blockList := make([]string, 0, count) - byteSize := sliceSize - fileMd5H := md5.New() - sliceMd5H := md5.New() - sliceMd5H2 := md5.New() - slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) - writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write} - if tmpF != nil { - writers = append(writers, tmpF) - } - written := int64(0) - for i := 1; i <= count; i++ { - if utils.IsCanceled(ctx) { - return nil, ctx.Err() - } - if i == count { - byteSize = lastBlockSize - } - n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize) - written += n - if err != nil && err != io.EOF { - return nil, err - } - blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) - sliceMd5H.Reset() - } - if tmpF != nil { - if written != streamSize { - return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize) - } - _, err = tmpF.Seek(0, io.SeekStart) - if err != nil { - return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") - } - } - contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) - sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) - blockListStr, _ := utils.Json.MarshalToString(blockList) path := stdpath.Join(dstDir.GetPath(), stream.GetName()) mtime := stream.ModTime().Unix() ctime := stream.CreateTime().Unix() - // step.1 尝试读取已保存进度 + // step.1 Compute the MD5 hashes by streaming + contentMd5, sliceMd5, blockList, ss, err := d.calculateHashesStream(ctx, stream, sliceSize, &up) + if err != nil { + return nil, err + } + + blockListStr, _ := utils.Json.MarshalToString(blockList) + + // step.2 Try to load saved progress, otherwise run precreate precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5) if !ok { // 没有进度,走预上传 @@ -288,6 +222,7 @@ func (d 
*BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return fileToObj(precreateResp.File), nil } } + ensureUploadURL := func() { if precreateResp.UploadURL != "" { return @@ -295,58 +230,24 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) } - // step.2 上传分片 + // step.3 Upload the chunks by streaming + // The streaming hash calculation has already consumed the stream, so the StreamSectionReader may need to be recreated + // If a cached file exists it can be reused directly; otherwise the data has to be re-fetched via RangeRead + if ss == nil || stream.GetFile() == nil { + // Recreate the StreamSectionReader + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } + } + uploadLoop: for range 2 { // 获取上传域名 ensureUploadURL() - // 并发上传 - threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, - retry.Attempts(UPLOAD_RETRY_COUNT), - retry.Delay(UPLOAD_RETRY_WAIT_TIME), - retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), - retry.DelayType(retry.BackOffDelay), - retry.RetryIf(func(err error) bool { - return !errors.Is(err, ErrUploadIDExpired) - }), - retry.LastErrorOnly(true)) - - totalParts := len(precreateResp.BlockList) - - for i, partseq := range precreateResp.BlockList { - if utils.IsCanceled(upCtx) { - break - } - if partseq < 0 { - continue - } - i, partseq := i, partseq - offset, size := int64(partseq)*sliceSize, sliceSize - if partseq+1 == count { - size = lastBlockSize - } - threadG.Go(func(ctx context.Context) error { - params := map[string]string{ - "method": "upload", - "access_token": d.AccessToken, - "type": "tmpfile", - "path": path, - "uploadid": precreateResp.Uploadid, - "partseq": strconv.Itoa(partseq), - } - section := io.NewSectionReader(cache, offset, size) - err := d.uploadSlice(ctx, precreateResp.UploadURL, params, stream.GetName(), section) - if err != nil { - return err - } - precreateResp.BlockList[i] = -1 - progress := float64(threadG.Success()+1) * 100 / float64(totalParts+1) - up(progress) - return nil - }) - } - err = threadG.Wait() + // Concurrent streaming upload + err = d.uploadChunksStream(ctx, ss, stream, precreateResp, path, sliceSize, count, up) if err == nil { break uploadLoop } @@ -372,13 +273,19 @@ uploadLoop: precreateResp.UploadURL = "" // 覆盖掉旧的进度 base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + + // Try to recreate the StreamSectionReader (only possible if the stream supports re-reading) + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } continue uploadLoop } return nil, err } defer up(100) - // step.3 创建文件 + // step.4 Create the file var newFile File _, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime) if err != nil { @@ -427,67 +334,6 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in return &precreateResp, nil } -func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file *io.SectionReader) error { - b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) - mw := multipart.NewWriter(b) - _, err := mw.CreateFormFile("file", fileName) - if err != nil { - return err - } - headSize := b.Len() - err = mw.Close() - if err != nil { - return err - } - head := bytes.NewReader(b.Bytes()[:headSize]) - tail := bytes.NewReader(b.Bytes()[headSize:]) - rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, file, tail)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) - if err != nil { - return err - } - query := req.URL.Query()
- for k, v := range params { - query.Set(k, v) - } - req.URL.RawQuery = query.Encode() - req.Header.Set("Content-Type", mw.FormDataContentType()) - req.ContentLength = int64(b.Len()) + file.Size() - - client := net.NewHttpClient() - if d.UploadSliceTimeout > 0 { - client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) - } else { - client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT - } - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - b.Reset() - _, err = b.ReadFrom(resp.Body) - if err != nil { - return err - } - body := b.Bytes() - respStr := string(body) - log.Debugln(respStr) - lower := strings.ToLower(respStr) - // 合并 uploadid 过期检测逻辑 - if strings.Contains(lower, "uploadid") && - (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { - return ErrUploadIDExpired - } - - errCode := utils.Json.Get(body, "error_code").ToInt() - errNo := utils.Json.Get(body, "errno").ToInt() - if errCode != 0 || errNo != 0 { - return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) - } - return nil -} func (d *BaiduNetdisk) GetDetails(ctx context.Context) (*model.StorageDetails, error) { du, err := d.quota(ctx) diff --git a/drivers/baidu_netdisk/types.go b/drivers/baidu_netdisk/types.go index 03e84b396..35886ce76 100644 --- a/drivers/baidu_netdisk/types.go +++ b/drivers/baidu_netdisk/types.go @@ -7,7 +7,6 @@ import ( "time" "github.com/OpenListTeam/OpenList/v4/internal/model" - "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) var ( @@ -76,9 +75,7 @@ func fileToObj(f File) *model.ObjThumb { Modified: time.Unix(f.ServerMtime, 0), Ctime: time.Unix(f.ServerCtime, 0), IsFolder: f.Isdir == 1, - - // 直接获取的MD5是错误的 - HashInfo: utils.NewHashInfo(utils.MD5, DecryptMd5(f.Md5)), + // The MD5 returned by the Baidu API is unreliable, so HashInfo is not set }, Thumbnail: model.Thumbnail{Thumbnail: f.Thumbs.Url3}, } diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go new file mode 100644 index 000000000..cc94d838e --- /dev/null +++ b/drivers/baidu_netdisk/upload.go @@ -0,0 +1,262 @@ +package baidu_netdisk + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "errors" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/net" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" +) + +// calculateHashesStream computes the file's MD5 hashes in a streaming fashion. +// It returns: the whole-file MD5, the MD5 of the first 256KB, the per-slice MD5 list, and the StreamSectionReader. +func (d *BaiduNetdisk) calculateHashesStream( + ctx context.Context, + stream model.FileStreamer, + sliceSize int64, + up *driver.UpdateProgress, +) (contentMd5 string, sliceMd5 string, blockList []string, ss streamPkg.StreamSectionReaderIF, err error) { + streamSize := stream.GetSize() + count := 1 + if streamSize > sliceSize { + count = int((streamSize + sliceSize - 1) / sliceSize) + } + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // Create a StreamSectionReader for streaming reads + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), nil) + if err != nil { + return "", "", nil, nil, err + } + + // MD5 of the first 256KB + const SliceSize int64 = 256 * utils.KB
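+	// The Baidu upload flow relies on three hashes computed in a single pass over the stream:
+	// the whole-file MD5 (content-md5), the MD5 of the first 256KB (slice-md5), and the MD5 of
+	// every slice (block_list). Computing them together means the stream is read only once.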
+ blockList = make([]string, 0, count) + fileMd5H := md5.New() + sliceMd5H2 := md5.New() + sliceWritten := int64(0) + + for i := 0; i < count; i++ { + if utils.IsCanceled(ctx) { + return "", "", nil, nil, ctx.Err() + } + + offset := int64(i) * sliceSize + length := sliceSize + if i == count-1 { + length = lastBlockSize + } + + reader, err := ss.GetSectionReader(offset, length) + if err != nil { + return "", "", nil, nil, err + } + + // Compute the MD5 of this slice + sliceMd5Calc := md5.New() + + // Feed all hash calculators at the same time + writers := []io.Writer{fileMd5H, sliceMd5Calc} + if sliceWritten < SliceSize { + remaining := SliceSize - sliceWritten + writers = append(writers, utils.LimitWriter(sliceMd5H2, remaining)) + } + + reader.Seek(0, io.SeekStart) + n, err := io.Copy(io.MultiWriter(writers...), reader) + if err != nil { + ss.FreeSectionReader(reader) + return "", "", nil, nil, err + } + sliceWritten += n + + blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) + ss.FreeSectionReader(reader) + + // Update the progress (hash calculation accounts for a small share of the total) + if up != nil { + progress := float64(i+1) * 10 / float64(count) + (*up)(progress) + } + } + + return hex.EncodeToString(fileMd5H.Sum(nil)), + hex.EncodeToString(sliceMd5H2.Sum(nil)), + blockList, ss, nil +} + +// uploadChunksStream uploads all slices in a streaming fashion +func (d *BaiduNetdisk) uploadChunksStream( + ctx context.Context, + ss streamPkg.StreamSectionReaderIF, + stream model.FileStreamer, + precreateResp *PrecreateResp, + path string, + sliceSize int64, + count int, + up driver.UpdateProgress, +) error { + streamSize := stream.GetSize() + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // Use OrderedGroup so the Before stage runs in order + thread := min(d.uploadThread, len(precreateResp.BlockList)) + threadG, upCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, + retry.Attempts(UPLOAD_RETRY_COUNT), + retry.Delay(UPLOAD_RETRY_WAIT_TIME), + retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), + retry.DelayType(retry.BackOffDelay), + retry.RetryIf(func(err error) bool { + return !errors.Is(err, ErrUploadIDExpired) + }), + retry.LastErrorOnly(true)) + + totalParts := len(precreateResp.BlockList) + + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + if partseq < 0 { + continue + } + + i, partseq := i, partseq + offset := int64(partseq) * sliceSize + size := sliceSize + if partseq+1 == count { + size = lastBlockSize + } + + var reader io.ReadSeeker + + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) error { + var err error + reader, err = ss.GetSectionReader(offset, size) + return err + }, + Do: func(ctx context.Context) error { + reader.Seek(0, io.SeekStart) + err := d.uploadSliceStream(ctx, precreateResp.UploadURL, path, + precreateResp.Uploadid, partseq, stream.GetName(), reader, size) + if err != nil { + return err + } + precreateResp.BlockList[i] = -1 + // Progress starts at 10% (the first 10% is hash calculation) + progress := 10 + float64(threadG.Success()+1)*90/float64(totalParts+1) + up(progress) + return nil + }, + After: func(err error) { + ss.FreeSectionReader(reader) + }, + }) + } + + return threadG.Wait() +} + +// uploadSliceStream uploads a single slice (accepts an io.ReadSeeker) +func (d *BaiduNetdisk) uploadSliceStream( + ctx context.Context, + uploadUrl string, + path string, + uploadid string, + partseq int, + fileName string, + reader io.ReadSeeker, + size int64, +) error { + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": uploadid, + "partseq": 
strconv.Itoa(partseq), + } + + b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) + mw := multipart.NewWriter(b) + _, err := mw.CreateFormFile("file", fileName) + if err != nil { + return err + } + headSize := b.Len() + err = mw.Close() + if err != nil { + return err + } + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, reader, tail)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) + if err != nil { + return err + } + query := req.URL.Query() + for k, v := range params { + query.Set(k, v) + } + req.URL.RawQuery = query.Encode() + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = int64(b.Len()) + size + + client := net.NewHttpClient() + if d.UploadSliceTimeout > 0 { + client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) + } else { + client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT + } + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + b.Reset() + _, err = b.ReadFrom(resp.Body) + if err != nil { + return err + } + body := b.Bytes() + respStr := string(body) + log.Debugln(respStr) + lower := strings.ToLower(respStr) + // 合并 uploadid 过期检测逻辑 + if strings.Contains(lower, "uploadid") && + (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { + return ErrUploadIDExpired + } + + errCode := utils.Json.Get(body, "error_code").ToInt() + errNo := utils.Json.Get(body, "errno").ToInt() + if errCode != 0 || errNo != 0 { + return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) + } + return nil +} \ No newline at end of file diff --git a/plans/baidu_netdisk_streaming_upload_design.md b/plans/baidu_netdisk_streaming_upload_design.md new file mode 100644 index 000000000..f23e407ff --- /dev/null +++ b/plans/baidu_netdisk_streaming_upload_design.md @@ -0,0 +1,815 @@ + +# 百度网盘流式上传重构方案 + +## 1. 背景与问题分析 + +### 1.1 当前实现问题 + +当前百度网盘上传实现位于 [`drivers/baidu_netdisk/driver.go`](drivers/baidu_netdisk/driver.go:191-393),存在以下问题: + +**问题代码(第202-217行):** +```go +var ( + cache = stream.GetFile() + tmpF *os.File + err error +) +if cache == nil { + tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") + // ... 创建临时文件缓存整个上传文件 + cache = tmpF +} +``` + +**根本原因:** +- 第337行使用 `io.NewSectionReader(cache, offset, size)` 需要 `io.ReaderAt` 接口 +- 为了支持并发上传分片,需要能够随机访问文件的任意位置 +- 当前解决方案是将整个文件缓存到本地临时文件 + +**影响:** +- 需要本地存储空间 ≥ 最大上传文件大小 +- 对于大文件上传,会占用大量磁盘空间 +- 在存储空间有限的环境下无法上传大文件 + +### 1.2 成功案例分析 + +#### 123open 流式上传机制([`drivers/123_open/upload.go`](drivers/123_open/upload.go:45-171)) + +```go +// 核心实现 +ss, err := stream.NewStreamSectionReader(file, int(chunkSize), &up) + +threadG, uploadCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, ...) + +for partIndex := range uploadNums { + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) (err error) { + reader, err = ss.GetSectionReader(offset, size) // 获取分片Reader + return + }, + Do: func(ctx context.Context) (err error) { + // 执行上传 + return nil + }, + After: func(err error) { + ss.FreeSectionReader(reader) // 释放Reader + }, + }) +} +``` + +**关键特点:** +1. 使用 `NewStreamSectionReader` 实现流式读取 +2. 使用 `NewOrderedGroupWithContext` 保证 Before 阶段有序执行 +3. 三阶段生命周期:Before(获取Reader)→ Do(上传)→ After(释放Reader) +4. 
流式表单构建:`io.MultiReader(head, reader, tail)` + +#### 115open 流式上传机制([`drivers/115_open/upload.go`](drivers/115_open/upload.go:72-149)) + +```go +ss, err := streamPkg.NewStreamSectionReader(stream, int(chunkSize), &up) + +for i := int64(1); i <= partNum; i++ { + rd, err := ss.GetSectionReader(offset, partSize) + err = retry.Do(func() error { + rd.Seek(0, io.SeekStart) + part, err := bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, rd), partSize, int(i)) + // ... + }, ...) + ss.FreeSectionReader(rd) +} +``` + +**关键特点:** +1. 顺序上传分片(for循环 + retry.Do) +2. 每个分片上传完成后立即释放 +3. 使用 OSS SDK 上传 + +### 1.3 StreamSectionReader 机制分析 + +[`internal/stream/util.go`](internal/stream/util.go:224-381) 中的 `StreamSectionReaderIF` 接口: + +```go +type StreamSectionReaderIF interface { + GetSectionReader(off, length int64) (io.ReadSeeker, error) // 获取分片Reader + FreeSectionReader(sr io.ReadSeeker) // 释放分片Reader + DiscardSection(off int64, length int64) error // 丢弃分片数据 +} +``` + +**实现类型:** +1. `cachedSectionReader` - 当文件已缓存时使用,直接返回 SectionReader +2. `directSectionReader` - 流式读取,使用内存缓冲池 + +**directSectionReader 工作原理:** +- 维护 `fileOffset` 跟踪当前读取位置 +- 使用 `bufPool` 缓冲池管理内存 +- `GetSectionReader` 从流中读取指定长度数据到缓冲区 +- `FreeSectionReader` 将缓冲区归还池中 + +**关键约束:** +- **线程不安全**:必须按顺序调用 `GetSectionReader` +- 请求的 offset 必须等于当前 fileOffset + +## 2. 整体架构设计 + +### 2.1 架构对比 + +``` +当前架构: +┌─────────────┐ ┌──────────────┐ ┌─────────────────┐ +│ FileStream │───>│ 临时文件缓存 │───>│ 并发分片上传 │ +└─────────────┘ │ (整个文件) │ │ (SectionReader) │ + └──────────────┘ └─────────────────┘ + +重构后架构: +┌─────────────┐ ┌────────────────────┐ ┌─────────────────┐ +│ FileStream │───>│ StreamSectionReader │───>│ 有序并发分片上传 │ +└─────────────┘ │ (按需缓存分片) │ │ (OrderedGroup) │ + └────────────────────┘ └─────────────────┘ +``` + +### 2.2 新架构流程图 + +```mermaid +flowchart TD + A[开始上传] --> B{文件大小检查} + B -->|空文件| C[返回错误] + B -->|非空| D[尝试秒传 PutRapid] + D -->|成功| E[返回文件对象] + D -->|失败| F[流式计算MD5和分片MD5] + F --> G[调用 precreate API] + G -->|return_type=2| E + G -->|return_type=1| H[创建 StreamSectionReader] + H --> I[创建 OrderedGroup] + I --> J[有序并发上传分片] + J --> K{上传结果} + K -->|成功| L[调用 create API 完成上传] + K -->|uploadid过期| M[重新 precreate] + M --> H + K -->|其他错误| N[保存进度并返回错误] + L --> E +``` + +### 2.3 分片上传详细流程 + +```mermaid +sequenceDiagram + participant Main as 主协程 + participant OG as OrderedGroup + participant G1 as Goroutine 1 + participant G2 as Goroutine 2 + participant SS as StreamSectionReader + participant API as 百度API + + Main->>OG: 创建 OrderedGroup + Main->>OG: GoWithLifecycle 分片1 + Main->>OG: GoWithLifecycle 分片2 + + Note over G1,SS: Before 阶段有序执行 + G1->>SS: GetSectionReader 分片1 + SS-->>G1: reader1 + G1->>OG: Before完成 释放锁 + + G2->>SS: GetSectionReader 分片2 + SS-->>G2: reader2 + G2->>OG: Before完成 释放锁 + + Note over G1,API: Do 阶段并发执行 + par 并发上传 + G1->>API: 上传分片1 + G2->>API: 上传分片2 + end + + API-->>G1: 成功 + API-->>G2: 成功 + + Note over G1,SS: After 阶段释放资源 + G1->>SS: FreeSectionReader reader1 + G2->>SS: FreeSectionReader reader2 +``` + +## 3. 
关键技术点设计 + +### 3.1 MD5计算策略 + +**当前实现问题(第230-258行):** +- 需要遍历整个文件计算 MD5 +- 同时计算 fileMd5、sliceMd5、blockList +- 如果没有缓存文件,需要边读边写临时文件 + +**重构方案:** + +使用 [`stream.StreamHashFile()`](internal/stream/util.go:182-222) 进行流式哈希计算: + +```go +// 伪代码 +func calculateHashes(ctx context.Context, stream model.FileStreamer, sliceSize int64) (contentMd5, sliceMd5 string, blockList []string, err error) { + size := stream.GetSize() + count := (size + sliceSize - 1) / sliceSize + + // 创建 StreamSectionReader 用于流式读取 + ss, err := stream.NewStreamSectionReader(stream, int(sliceSize), nil) + if err != nil { + return "", "", nil, err + } + + fileMd5H := md5.New() + sliceMd5H := md5.New() // 前256KB + blockList = make([]string, 0, count) + + const SliceSize = 256 * 1024 + sliceWritten := int64(0) + + for i := 0; i < count; i++ { + offset := int64(i) * sliceSize + length := min(sliceSize, size-offset) + + reader, err := ss.GetSectionReader(offset, length) + if err != nil { + return "", "", nil, err + } + + // 计算分片MD5 + sliceMd5Calc := md5.New() + + // 同时写入多个哈希计算器 + writers := []io.Writer{fileMd5H, sliceMd5Calc} + if sliceWritten < SliceSize { + writers = append(writers, utils.LimitWriter(sliceMd5H, SliceSize-sliceWritten)) + } + + reader.Seek(0, io.SeekStart) + n, err := io.Copy(io.MultiWriter(writers...), reader) + if err != nil { + ss.FreeSectionReader(reader) + return "", "", nil, err + } + sliceWritten += n + + blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) + ss.FreeSectionReader(reader) + } + + return hex.EncodeToString(fileMd5H.Sum(nil)), + hex.EncodeToString(sliceMd5H.Sum(nil)), + blockList, nil +} +``` + +**关键改进:** +1. 使用 StreamSectionReader 按分片读取 +2. 每个分片读取后立即释放内存 +3. 同时计算所有需要的哈希值 +4. 不需要创建临时文件 + +### 3.2 分片上传逻辑 + +**当前实现(第316-347行):** +```go +for i, partseq := range precreateResp.BlockList { + threadG.Go(func(ctx context.Context) error { + section := io.NewSectionReader(cache, offset, size) // 需要 ReaderAt + err := d.uploadSlice(ctx, ...) 
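+		// NOTE: io.NewSectionReader requires `cache` to implement io.ReaderAt,
+		// which is why the whole file has to be spooled to a local temp file first.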
+ return err + }) +} +``` + +**重构方案:** + +```go +// 伪代码 +func (d *BaiduNetdisk) uploadChunks(ctx context.Context, stream model.FileStreamer, + precreateResp *PrecreateResp, sliceSize int64, up driver.UpdateProgress) error { + + size := stream.GetSize() + count := int((size + sliceSize - 1) / sliceSize) + + // 创建 StreamSectionReader + ss, err := stream.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return err + } + + // 使用 OrderedGroup 保证 Before 阶段有序 + thread := min(d.uploadThread, len(precreateResp.BlockList)) + threadG, upCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, + retry.Attempts(UPLOAD_RETRY_COUNT), + retry.Delay(UPLOAD_RETRY_WAIT_TIME), + retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), + retry.DelayType(retry.BackOffDelay), + retry.RetryIf(func(err error) bool { + return !errors.Is(err, ErrUploadIDExpired) + }), + retry.LastErrorOnly(true)) + + totalParts := len(precreateResp.BlockList) + + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + if partseq < 0 { + continue + } + + i, partseq := i, partseq + offset := int64(partseq) * sliceSize + length := min(sliceSize, size-offset) + + var reader io.ReadSeeker + + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) error { + var err error + reader, err = ss.GetSectionReader(offset, length) + return err + }, + Do: func(ctx context.Context) error { + reader.Seek(0, io.SeekStart) + err := d.uploadSliceStream(ctx, precreateResp.UploadURL, + precreateResp.Uploadid, partseq, stream.GetName(), reader, length) + if err != nil { + return err + } + precreateResp.BlockList[i] = -1 + progress := float64(threadG.Success()+1) * 100 / float64(totalParts+1) + up(progress) + return nil + }, + After: func(err error) { + ss.FreeSectionReader(reader) + }, + }) + } + + return threadG.Wait() +} +``` + +### 3.3 表单构建方式 + +**当前实现(第430-456行):** +```go +func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, + params map[string]string, fileName string, file *io.SectionReader) error { + + b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) + mw := multipart.NewWriter(b) + _, err := mw.CreateFormFile("file", fileName) + // ... + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, file, tail)) + // ... 
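+	// The multipart form is rendered once into `b`: everything before the file body becomes `head`
+	// and the closing boundary becomes `tail`, so the slice itself can be streamed between the two
+	// without being buffered. Content-Length is therefore len(head)+len(tail) plus the slice size.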
+} +``` + +**重构方案:** + +修改 `uploadSlice` 函数签名,接受 `io.ReadSeeker` 而不是 `*io.SectionReader`: + +```go +// 伪代码 +func (d *BaiduNetdisk) uploadSliceStream(ctx context.Context, uploadUrl string, + uploadid string, partseq int, fileName string, reader io.ReadSeeker, size int64) error { + + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": uploadid, + "partseq": strconv.Itoa(partseq), + } + + // 构建 multipart 表单 + b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) + mw := multipart.NewWriter(b) + _, err := mw.CreateFormFile("file", fileName) + if err != nil { + return err + } + headSize := b.Len() + err = mw.Close() + if err != nil { + return err + } + + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + + // 使用 io.ReadSeeker 而不是 *io.SectionReader + rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, reader, tail)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) + if err != nil { + return err + } + + // 设置请求参数和头 + query := req.URL.Query() + for k, v := range params { + query.Set(k, v) + } + req.URL.RawQuery = query.Encode() + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = int64(b.Len()) + size + + // 发送请求... +} +``` + +### 3.4 uploadid过期重试 + +**当前实现(第361-375行):** +```go +if errors.Is(err, ErrUploadIDExpired) { + log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch") + newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime) + // ... + precreateResp = newPre + precreateResp.UploadURL = "" + base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + continue uploadLoop +} +``` + +**重构方案:** + +由于流式上传无法回退重读,uploadid 过期时需要特殊处理: + +```go +// 伪代码 +uploadLoop: +for range 2 { + // 重新创建 StreamSectionReader(如果是重试) + if ss == nil { + // 第一次或重试时需要重新获取流 + // 这里需要依赖上层提供可重新获取的流 + ss, err = stream.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } + } + + err = d.uploadChunks(ctx, ss, precreateResp, sliceSize, up) + + if err == nil { + break uploadLoop + } + + if errors.Is(err, ErrUploadIDExpired) { + log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch") + + // 重新 precreate + newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime) + if err2 != nil { + return nil, err2 + } + if newPre.ReturnType == 2 { + return fileToObj(newPre.File), nil + } + precreateResp = newPre + precreateResp.UploadURL = "" + + // 需要重新获取流 - 这是流式上传的限制 + // 如果流不支持重新读取,则无法重试 + ss = nil // 标记需要重新创建 + + base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + continue uploadLoop + } + + return nil, err +} +``` + +**重要说明:** +- 流式上传的一个限制是无法回退重读 +- 如果 uploadid 过期,需要重新获取流 +- 对于 SeekableStream,可以通过 RangeRead 重新获取 +- 对于纯流式输入,可能无法支持 uploadid 过期重试 + +## 4. 代码改造范围 + +### 4.1 需要修改的函数 + +| 函数 | 文件位置 | 修改内容 | +|------|----------|----------| +| `Put` | driver.go:191-393 | 主要重构目标,移除临时文件缓存 | +| `uploadSlice` | driver.go:430-490 | 修改参数类型,支持 io.ReadSeeker | + +### 4.2 需要新增的辅助函数 + +```go +// 1. 流式计算哈希值 +func (d *BaiduNetdisk) calculateHashesStream(ctx context.Context, stream model.FileStreamer, + sliceSize int64) (contentMd5, sliceMd5 string, blockList []string, err error) + +// 2. 
流式分片上传 +func (d *BaiduNetdisk) uploadChunksStream(ctx context.Context, stream model.FileStreamer, + precreateResp *PrecreateResp, sliceSize int64, path string, + up driver.UpdateProgress) error + +// 3. 修改后的分片上传(接受 io.ReadSeeker) +func (d *BaiduNetdisk) uploadSliceStream(ctx context.Context, uploadUrl string, + uploadid string, partseq int, path string, fileName string, + reader io.ReadSeeker, size int64) error +``` + +### 4.3 需要删除的代码段 + +| 代码行 | 内容 | 原因 | +|--------|------|------| +| 202-217 | 临时文件创建逻辑 | 不再需要缓存整个文件 | +| 230-258 | 同步MD5计算循环 | 改用流式计算 | +| 259-267 | 临时文件seek和错误处理 | 不再使用临时文件 | + +### 4.4 对现有API调用的影响 + +| API | 影响 | 说明 | +|-----|------|------| +| `precreate` | 无变化 | 参数不变 | +| `uploadSlice` | 参数变化 | file 参数从 `*io.SectionReader` 改为 `io.ReadSeeker` | +| `create` | 无变化 | 参数不变 | +| `getUploadUrl` | 无变化 | 参数不变 | + +## 5. 兼容性考虑 + +### 5.1 断点续传功能 + +**保留方式:** +- 继续使用 `base.SaveUploadProgress` 和 `base.GetUploadProgress` +- 保存 `PrecreateResp` 包含 `BlockList`(未完成的分片列表) +- 恢复时跳过已完成的分片 + +**限制:** +- 流式上传无法跳过已完成的分片(流不能回退) +- 断点续传仅在有缓存文件时有效 +- 对于纯流式输入,断点续传功能受限 + +**解决方案:** +```go +// 检查是否支持断点续传 +func canResumeUpload(stream model.FileStreamer, precreateResp *PrecreateResp) bool { + // 如果有缓存文件,可以断点续传 + if stream.GetFile() != nil { + return true + } + // 如果所有分片都需要上传,可以从头开始 + if len(precreateResp.BlockList) == totalParts { + return true + } + // 否则无法断点续传 + return false +} +``` + +### 5.2 并发上传功能 + +**保留方式:** +- 使用 `errgroup.NewOrderedGroupWithContext` 替代 `errgroup.NewGroupWithContext` +- Before 阶段有序执行,保证流式读取顺序 +- Do 阶段并发执行,保持上传并发性 + +**关键点:** +- OrderedGroup 的 Before 阶段是串行的 +- Do 阶段是并发的 +- 这样既保证了流式读取的顺序性,又保持了上传的并发性 + +### 5.3 小文件处理(< 4MB) + +**当前行为:** +- 小文件也走分片上传流程 +- 只有一个分片 + +**重构后行为:** +- 保持不变 +- 小文件同样使用 StreamSectionReader +- 只创建一个分片的缓冲区 + +### 5.4 秒传(PutRapid) + +**当前实现:** +```go +if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil { + return newObj, nil +} +``` + +**重构后:** +- 保持不变 +- 秒传在流式上传之前尝试 +- 如果秒传成功,不需要实际上传数据 + +## 6. 实现步骤 + +### 6.1 阶段一:准备工作 + +1. **创建新的上传函数文件** + - 创建 `drivers/baidu_netdisk/upload.go` + - 将上传相关函数移到新文件 + +2. **添加流式哈希计算函数** + - 实现 `calculateHashesStream` 函数 + - 使用 StreamSectionReader 按分片计算 + +### 6.2 阶段二:核心重构 + +3. **修改 uploadSlice 函数** + - 创建新函数 `uploadSliceStream` + - 接受 `io.ReadSeeker` 参数 + - 保留原函数用于兼容 + +4. **实现流式分片上传** + - 创建 `uploadChunksStream` 函数 + - 使用 OrderedGroup 实现有序并发 + - 实现三阶段生命周期 + +5. **重构 Put 函数** + - 移除临时文件创建逻辑 + - 使用流式哈希计算 + - 调用流式分片上传 + +### 6.3 阶段三:完善与测试 + +6. **处理断点续传** + - 检测是否支持断点续传 + - 对于不支持的情况给出警告 + +7. **处理 uploadid 过期** + - 实现重试逻辑 + - 处理流不可重读的情况 + +8. **清理旧代码** + - 删除不再需要的临时文件逻辑 + - 删除旧的 uploadSlice 函数(如果不再需要) + +### 6.4 验证方法 + +| 步骤 | 验证方法 | +|------|----------| +| 流式哈希计算 | 对比新旧函数计算结果是否一致 | +| 分片上传 | 上传测试文件,验证文件完整性 | +| 并发上传 | 使用多线程配置,验证上传速度 | +| 断点续传 | 中断上传后恢复,验证续传功能 | +| uploadid过期 | 模拟过期场景,验证重试逻辑 | +| 小文件 | 上传小于4MB的文件 | +| 大文件 | 上传超过100MB的文件 | + +### 6.5 潜在风险和应对方案 + +| 风险 | 影响 | 应对方案 | +|------|------|----------| +| 流不可重读 | uploadid过期时无法重试 | 检测流类型,对于SeekableStream支持重试 | +| 内存占用 | 并发上传时多个分片同时在内存 | 限制并发数,使用缓冲池 | +| 哈希计算错误 | 上传失败或数据损坏 | 充分测试,对比旧实现结果 | +| 断点续传失效 | 用户体验下降 | 对于有缓存文件的情况保持支持 | + +## 7. 测试验证方案 + +### 7.1 单元测试 + +```go +// 测试流式哈希计算 +func TestCalculateHashesStream(t *testing.T) { + // 创建测试文件流 + // 对比新旧函数计算结果 +} + +// 测试分片上传 +func TestUploadChunksStream(t *testing.T) { + // Mock 百度API + // 验证分片顺序和内容 +} +``` + +### 7.2 集成测试 + +1. **小文件上传测试**(< 4MB) + - 验证单分片上传 + - 验证文件完整性 + +2. **大文件上传测试**(> 100MB) + - 验证多分片上传 + - 验证并发上传 + - 验证内存占用 + +3. **断点续传测试** + - 中断上传 + - 恢复上传 + - 验证续传正确性 + +4. 
**错误处理测试** + - 网络错误重试 + - uploadid过期重试 + - 取消上传 + +### 7.3 性能测试 + +1. **内存占用对比** + - 旧实现:临时文件大小 = 文件大小 + - 新实现:内存占用 ≈ 分片大小 × 并发数 + +2. **上传速度对比** + - 测试不同文件大小的上传速度 + - 测试不同并发数的影响 + +## 8. 改造前后对比 + +### 8.1 代码结构对比 + +| 方面 | 改造前 | 改造后 | +|------|--------|--------| +| 临时文件 | 需要创建整个文件大小的临时文件 | 不需要临时文件 | +| 内存占用 | 低(数据在磁盘) | 中等(分片大小 × 并发数) | +| 磁盘占用 | 高(= 文件大小) | 无 | +| 并发模型 | errgroup.NewGroupWithContext | errgroup.NewOrderedGroupWithContext | +| 分片读取 | io.NewSectionReader | StreamSectionReader | +| MD5计算 | 同步遍历整个文件 | 流式按分片计算 | + +### 8.2 API调用流程对比 + +**改造前:** +``` +1. 创建临时文件 +2. 读取流 → 写入临时文件(同时计算MD5) +3. precreate API +4. 并发上传分片(从临时文件随机读取) +5. create API +6. 删除临时文件 +``` + +**改造后:** +``` +1. 创建 StreamSectionReader +2. 流式计算MD5(按分片读取) +3. precreate API +4. 有序并发上传分片(Before有序,Do并发) +5. create API +6. 释放 StreamSectionReader +``` + +### 8.3 资源使用对比 + +| 资源 | 改造前 | 改造后 | +|------|--------|--------| +| 磁盘空间 | O(n) - 文件大小 | O(1) - 无 | +| 内存 | O(1) - 固定缓冲区 | O(k) - k为并发数×分片大小 | +| 文件句柄 | 1个临时文件 | 0个 | + +## 9. 总结 + +### 9.1 设计要点 + +1. **使用 StreamSectionReader 替代临时文件** + - 按需读取分片数据 + - 使用内存缓冲池管理内存 + - 读取后立即释放 + +2. **使用 OrderedGroup 实现有序并发** + - Before 阶段有序执行,保证流式读取顺序 + - Do 阶段并发执行,保持上传并发性 + - After 阶段释放资源 + +3. **流式计算MD5** + - 按分片读取数据 + - 同时计算多个哈希值 + - 不需要缓存整个文件 + +4. **保持兼容性** + - 保留断点续传功能(有限制) + - 保留并发上传功能 + - 保留秒传功能 + +### 9.2 实现优先级 + +1. **高优先级** + - 流式分片上传(核心功能) + - 流式MD5计算 + - OrderedGroup 并发控制 + +2. **中优先级** + - uploadid 过期重试 + - 断点续传支持 + +3. **低优先级** + - 性能优化 + - 错误处理完善 + +### 9.3 预期收益 + +1. **消除磁盘空间限制** + - 不再需要本地存储空间 ≥ 最大上传文件大小 + - 支持在存储空间有限的环境下上传大文件 + +2. **减少I/O操作** + - 不需要先写入临时文件再读取 + - 数据直接从源流传输到网络 + +3. **提高资源利用效率** + - 内存按需分配 + - 使用缓冲池复用内存 \ No newline at end of file
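+
+## Appendix: Single-Pass Hashing Sketch
+
+The following is a minimal, self-contained sketch (standard library only, **not** the driver code) of the single-pass hashing scheme described in section 3.1. The names `hashOnePass` and `limitedWriter` are illustrative assumptions; the driver itself uses `calculateHashesStream` together with `utils.LimitWriter` and a `StreamSectionReader`.
+
+```go
+package main
+
+import (
+	"bytes"
+	"crypto/md5"
+	"encoding/hex"
+	"fmt"
+	"io"
+)
+
+// limitedWriter mirrors the role of utils.LimitWriter in the driver:
+// it forwards at most n bytes to w and silently drops the rest.
+type limitedWriter struct {
+	w io.Writer
+	n int64
+}
+
+func (l *limitedWriter) Write(p []byte) (int, error) {
+	if l.n > 0 {
+		take := int64(len(p))
+		if take > l.n {
+			take = l.n
+		}
+		l.w.Write(p[:int(take)]) // hash.Hash writes never fail
+		l.n -= take
+	}
+	return len(p), nil // report the full length so io.MultiWriter keeps going
+}
+
+// hashOnePass reads r exactly once and returns the whole-file MD5 (content-md5),
+// the MD5 of the first 256KB (slice-md5) and the MD5 of every sliceSize chunk (block_list).
+func hashOnePass(r io.Reader, size, sliceSize int64) (string, string, []string, error) {
+	fileH := md5.New()
+	sliceH := md5.New()
+	first256K := &limitedWriter{w: sliceH, n: 256 * 1024}
+	var blockList []string
+
+	for off := int64(0); off < size; off += sliceSize {
+		n := min(sliceSize, size-off)
+		blockH := md5.New()
+		// Every byte of the chunk feeds the whole-file hash, the per-chunk hash
+		// and (for the first 256KB only) the slice hash.
+		if _, err := io.CopyN(io.MultiWriter(fileH, blockH, first256K), r, n); err != nil {
+			return "", "", nil, err
+		}
+		blockList = append(blockList, hex.EncodeToString(blockH.Sum(nil)))
+	}
+	return hex.EncodeToString(fileH.Sum(nil)), hex.EncodeToString(sliceH.Sum(nil)), blockList, nil
+}
+
+func main() {
+	data := bytes.Repeat([]byte("x"), 5<<20) // 5 MiB of test data, 4 MiB slices -> 2 blocks
+	contentMd5, sliceMd5, blocks, err := hashOnePass(bytes.NewReader(data), int64(len(data)), 4<<20)
+	fmt.Println(contentMd5, sliceMd5, len(blocks), err)
+}
+```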