From 0c9dcec9cdcb0416e630116adee090b4640635e8 Mon Sep 17 00:00:00 2001 From: jack roble <74554363+0daysseus@users.noreply.github.com> Date: Fri, 19 Apr 2024 05:22:16 -0400 Subject: [PATCH 1/5] fix: init storages in order (#6346) --- internal/bootstrap/storage.go | 4 ++-- internal/db/storage.go | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/bootstrap/storage.go b/internal/bootstrap/storage.go index 44af99c56e12..86288edcb161 100644 --- a/internal/bootstrap/storage.go +++ b/internal/bootstrap/storage.go @@ -21,8 +21,8 @@ func LoadStorages() { if err != nil { utils.Log.Errorf("failed get enabled storages: %+v", err) } else { - utils.Log.Infof("success load storage: [%s], driver: [%s]", - storages[i].MountPath, storages[i].Driver) + utils.Log.Infof("success load storage: [%s], driver: [%s], order: [%d]", + storages[i].MountPath, storages[i].Driver, storages[i].Order) } } conf.StoragesLoaded = true diff --git a/internal/db/storage.go b/internal/db/storage.go index 105bc0aafda3..d4e0730f0641 100644 --- a/internal/db/storage.go +++ b/internal/db/storage.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "sort" "github.com/alist-org/alist/v3/internal/model" "github.com/pkg/errors" @@ -65,5 +66,8 @@ func GetEnabledStorages() ([]model.Storage, error) { if err := db.Where(fmt.Sprintf("%s = ?", columnName("disabled")), false).Find(&storages).Error; err != nil { return nil, errors.WithStack(err) } + sort.Slice(storages, func(i, j int) bool { + return storages[i].Order < storages[j].Order + }) return storages, nil } From 32ddab9b0131885045971b3d6bd59605d4ffc9ac Mon Sep 17 00:00:00 2001 From: Xiaoran Studio Date: Wed, 24 Apr 2024 14:54:01 +0800 Subject: [PATCH 2/5] feat(123_share): add access token (#6357) --- drivers/123_share/driver.go | 11 ++++++++ drivers/123_share/meta.go | 3 ++- drivers/123_share/util.go | 50 +++++++++++++++++++++++++++++++------ 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/drivers/123_share/driver.go b/drivers/123_share/driver.go index b2fd4313331b..7fca7cc145e8 100644 --- a/drivers/123_share/driver.go +++ b/drivers/123_share/driver.go @@ -4,8 +4,11 @@ import ( "context" "encoding/base64" "fmt" + "golang.org/x/time/rate" "net/http" "net/url" + "sync" + "time" "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" @@ -19,6 +22,7 @@ import ( type Pan123Share struct { model.Storage Addition + apiRateLimit sync.Map } func (d *Pan123Share) Config() driver.Config { @@ -146,4 +150,11 @@ func (d *Pan123Share) Put(ctx context.Context, dstDir model.Obj, stream model.Fi // return nil, errs.NotSupport //} +func (d *Pan123Share) APIRateLimit(api string) bool { + limiter, _ := d.apiRateLimit.LoadOrStore(api, + rate.NewLimiter(rate.Every(time.Millisecond*700), 1)) + ins := limiter.(*rate.Limiter) + return ins.Allow() +} + var _ driver.Driver = (*Pan123Share)(nil) diff --git a/drivers/123_share/meta.go b/drivers/123_share/meta.go index a4bb14a95932..ce39b7eee07c 100644 --- a/drivers/123_share/meta.go +++ b/drivers/123_share/meta.go @@ -7,10 +7,11 @@ import ( type Addition struct { ShareKey string `json:"sharekey" required:"true"` - SharePwd string `json:"sharepassword" required:"true"` + SharePwd string `json:"sharepassword"` driver.RootID OrderBy string `json:"order_by" type:"select" options:"file_name,size,update_at" default:"file_name"` OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + AccessToken string `json:"accesstoken" type:"text"` } var config = driver.Config{ diff --git a/drivers/123_share/util.go b/drivers/123_share/util.go index bfce54f3cc07..b22b7cc45474 100644 --- a/drivers/123_share/util.go +++ b/drivers/123_share/util.go @@ -2,8 +2,15 @@ package _123Share import ( "errors" + "fmt" + "hash/crc32" + "math" + "math/rand" "net/http" + "net/url" "strconv" + "strings" + "time" "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/pkg/utils" @@ -15,20 +22,45 @@ const ( Api = "https://www.123pan.com/api" AApi = "https://www.123pan.com/a/api" BApi = "https://www.123pan.com/b/api" - MainApi = Api + MainApi = BApi FileList = MainApi + "/share/get" DownloadInfo = MainApi + "/share/download/info" //AuthKeySalt = "8-8D$sL8gPjom7bk#cY" ) +func signPath(path string, os string, version string) (k string, v string) { + table := []byte{'a', 'd', 'e', 'f', 'g', 'h', 'l', 'm', 'y', 'i', 'j', 'n', 'o', 'p', 'k', 'q', 'r', 's', 't', 'u', 'b', 'c', 'v', 'w', 's', 'z'} + random := fmt.Sprintf("%.f", math.Round(1e7*rand.Float64())) + now := time.Now().In(time.FixedZone("CST", 8*3600)) + timestamp := fmt.Sprint(now.Unix()) + nowStr := []byte(now.Format("200601021504")) + for i := 0; i < len(nowStr); i++ { + nowStr[i] = table[nowStr[i]-48] + } + timeSign := fmt.Sprint(crc32.ChecksumIEEE(nowStr)) + data := strings.Join([]string{timestamp, random, path, os, version, timeSign}, "|") + dataSign := fmt.Sprint(crc32.ChecksumIEEE([]byte(data))) + return timeSign, strings.Join([]string{timestamp, random, dataSign}, "-") +} + +func GetApi(rawUrl string) string { + u, _ := url.Parse(rawUrl) + query := u.Query() + query.Add(signPath(u.Path, "web", "3")) + u.RawQuery = query.Encode() + return u.String() +} + func (d *Pan123Share) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { req := base.RestyClient.R() req.SetHeaders(map[string]string{ - "origin": "https://www.123pan.com", - "referer": "https://www.123pan.com/", - "user-agent": "Dart/2.19(dart:io)", - "platform": "android", - "app-version": "36", + "origin": "https://www.123pan.com", + "referer": "https://www.123pan.com/", + "authorization": "Bearer " + d.AccessToken, + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) alist-client", + "platform": "web", + "app-version": "3", + //"user-agent": base.UserAgent, }) if callback != nil { callback(req) @@ -36,7 +68,7 @@ func (d *Pan123Share) request(url string, method string, callback base.ReqCallba if resp != nil { req.SetResult(resp) } - res, err := req.Execute(method, url) + res, err := req.Execute(method, GetApi(url)) if err != nil { return nil, err } @@ -52,6 +84,10 @@ func (d *Pan123Share) getFiles(parentId string) ([]File, error) { page := 1 res := make([]File, 0) for { + if !d.APIRateLimit(FileList) { + time.Sleep(time.Millisecond * 200) + continue + } var resp Files query := map[string]string{ "limit": "100", From 479fc6d4663aa5fd137ea56f7b6ba81377c6442e Mon Sep 17 00:00:00 2001 From: potoo <34411681+potoo0@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:13:30 +0800 Subject: [PATCH 3/5] fix(webdav): make sure `Mtime` after `Ctime` (#6372 close #6371) * fix(server/webdav) make sure Mtime >= Ctime * fix(server/webdav) avoid variable 'stream' collides with imported package name --- server/webdav/util.go | 13 +++++++++---- server/webdav/webdav.go | 10 +++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/server/webdav/util.go b/server/webdav/util.go index 15d9e07cc560..a2a1641ce62f 100644 --- a/server/webdav/util.go +++ b/server/webdav/util.go @@ -8,16 +8,21 @@ import ( ) func (h *Handler) getModTime(r *http.Request) time.Time { - return h.getHeaderTime(r, "X-OC-Mtime") + return h.getHeaderTime(r, "X-OC-Mtime", "") } -// owncloud/ nextcloud haven't impl this, but we can add the support since rclone may support this soon +// owncloud/ nextcloud haven't impl this, but we can add the support since rclone may support this soon. +// try ModTime if CreateTime not found in header func (h *Handler) getCreateTime(r *http.Request) time.Time { - return h.getHeaderTime(r, "X-OC-Ctime") + return h.getHeaderTime(r, "X-OC-Ctime", "X-OC-Mtime") } -func (h *Handler) getHeaderTime(r *http.Request, header string) time.Time { +func (h *Handler) getHeaderTime(r *http.Request, header, alternative string) time.Time { hVal := r.Header.Get(header) + // try alternative + if hVal == "" && alternative != "" { + hVal = r.Header.Get(alternative) + } if hVal != "" { modTimeUnix, err := strconv.ParseInt(hVal, 10, 64) if err == nil { diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 390e54099761..6054991a0c2b 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -331,21 +331,21 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, Modified: h.getModTime(r), Ctime: h.getCreateTime(r), } - stream := &stream.FileStream{ + fsStream := &stream.FileStream{ Obj: &obj, Reader: r.Body, Mimetype: r.Header.Get("Content-Type"), } - if stream.Mimetype == "" { - stream.Mimetype = utils.GetMimeType(reqPath) + if fsStream.Mimetype == "" { + fsStream.Mimetype = utils.GetMimeType(reqPath) } - err = fs.PutDirectly(ctx, path.Dir(reqPath), stream) + err = fs.PutDirectly(ctx, path.Dir(reqPath), fsStream) if errs.IsNotFoundError(err) { return http.StatusNotFound, err } _ = r.Body.Close() - _ = stream.Close() + _ = fsStream.Close() // TODO(rost): Returning 405 Method Not Allowed might not be appropriate. if err != nil { return http.StatusMethodNotAllowed, err From ec08ecdf6cdb69bb776d5fbb539a842ef9c946d3 Mon Sep 17 00:00:00 2001 From: potoo <34411681+potoo0@users.noreply.github.com> Date: Thu, 25 Apr 2024 20:08:20 +0800 Subject: [PATCH 4/5] fix(baidu_netdisk): cached Ctime/Mtime (#6373 close #6370) (cherry picked from commit 23542541e4f343d484de1f83ee5c928d2ab6753c) --- drivers/baidu_netdisk/driver.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 20810a768dec..43da834a143c 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -165,9 +165,16 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo if err != nil { return nil, err } + // 修复时间,具体原因见 Put 方法注释的 **注意** + newFile.Ctime = stream.CreateTime().Unix() + newFile.Mtime = stream.ModTime().Unix() return fileToObj(newFile), nil } +// Put +// +// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。 +// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致 func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { // rapid upload if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil { @@ -245,9 +252,9 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F log.Debugf("%+v", precreateResp) if precreateResp.ReturnType == 2 { //rapid upload, since got md5 match from baidu server - if err != nil { - return nil, err - } + // 修复时间,具体原因见 Put 方法注释的 **注意** + precreateResp.File.Ctime = ctime + precreateResp.File.Mtime = mtime return fileToObj(precreateResp.File), nil } } @@ -298,6 +305,9 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F if err != nil { return nil, err } + // 修复时间,具体原因见 Put 方法注释的 **注意** + newFile.Ctime = ctime + newFile.Mtime = mtime return fileToObj(newFile), nil } From b95df1d7457c6a72cf7e8242061e8a52fc9a3095 Mon Sep 17 00:00:00 2001 From: Mmx Date: Thu, 25 Apr 2024 20:11:15 +0800 Subject: [PATCH 5/5] perf: use io copy with buffer pool (#6389) * feat: add io methods with buffer * chore: move io.Copy calls to utils.CopyWithBuffer --- drivers/123/driver.go | 2 +- drivers/189pc/utils.go | 2 +- drivers/aliyundrive/driver.go | 2 +- drivers/aliyundrive_open/upload.go | 2 +- drivers/baidu_netdisk/driver.go | 2 +- drivers/baidu_photo/driver.go | 2 +- drivers/chaoxing/driver.go | 2 +- drivers/ilanzou/driver.go | 2 +- drivers/mediatrack/driver.go | 2 +- drivers/pikpak/util.go | 3 ++- drivers/quark_uc/driver.go | 4 ++-- drivers/smb/util.go | 4 ++-- drivers/thunder/util.go | 2 +- internal/net/request.go | 3 ++- internal/net/serve.go | 4 ++-- internal/net/util.go | 3 ++- internal/stream/stream.go | 2 +- pkg/gowebdav/client.go | 3 ++- pkg/utils/file.go | 4 ++-- pkg/utils/hash.go | 2 +- pkg/utils/hash_test.go | 3 +-- pkg/utils/io.go | 31 +++++++++++++++++++++++++++++- 22 files changed, 59 insertions(+), 27 deletions(-) diff --git a/drivers/123/driver.go b/drivers/123/driver.go index f5d981ef6361..240027405d53 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -194,7 +194,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr defer func() { _ = tempFile.Close() }() - if _, err = io.Copy(h, tempFile); err != nil { + if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { return err } _, err = tempFile.Seek(0, io.SeekStart) diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index ee96af3e1603..a000a84e0053 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -595,7 +595,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode } silceMd5.Reset() - if _, err := io.CopyN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF { + if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF { return nil, err } md5Byte := silceMd5.Sum(nil) diff --git a/drivers/aliyundrive/driver.go b/drivers/aliyundrive/driver.go index eab38f58e1c4..2a977aa35e50 100644 --- a/drivers/aliyundrive/driver.go +++ b/drivers/aliyundrive/driver.go @@ -194,7 +194,7 @@ func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.Fil } if d.RapidUpload { buf := bytes.NewBuffer(make([]byte, 0, 1024)) - io.CopyN(buf, file, 1024) + utils.CopyWithBufferN(buf, file, 1024) reqBody["pre_hash"] = utils.HashData(utils.SHA1, buf.Bytes()) if localFile != nil { if _, err := localFile.Seek(0, io.SeekStart); err != nil { diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index 5f57e8b56200..d152836c075b 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -136,7 +136,7 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error if err != nil { return "", err } - _, err = io.CopyN(buf, reader, length) + _, err = utils.CopyWithBufferN(buf, reader, length) if err != nil { return "", err } diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 43da834a143c..ad52a4b54384 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -211,7 +211,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F if i == count { byteSize = lastBlockSize } - _, err := io.CopyN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) if err != nil && err != io.EOF { return nil, err } diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go index c29bc110095a..7477a8eb5277 100644 --- a/drivers/baidu_photo/driver.go +++ b/drivers/baidu_photo/driver.go @@ -261,7 +261,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if i == count { byteSize = lastBlockSize } - _, err := io.CopyN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) if err != nil && err != io.EOF { return nil, err } diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go index 143235fa481f..de122c36c4d3 100644 --- a/drivers/chaoxing/driver.go +++ b/drivers/chaoxing/driver.go @@ -229,7 +229,7 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileS if err != nil { return err } - _, err = io.Copy(filePart, stream) + _, err = utils.CopyWithBuffer(filePart, stream) if err != nil { return err } diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go index 341136da1cd8..1d8e5d36b09b 100644 --- a/drivers/ilanzou/driver.go +++ b/drivers/ilanzou/driver.go @@ -271,7 +271,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt defer func() { _ = tempFile.Close() }() - if _, err = io.Copy(h, tempFile); err != nil { + if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { return nil, err } _, err = tempFile.Seek(0, io.SeekStart) diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index ef571832eb70..f0f1ded00872 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -206,7 +206,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil return err } h := md5.New() - _, err = io.Copy(h, tempFile) + _, err = utils.CopyWithBuffer(h, tempFile) if err != nil { return err } diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go index 02b988bcd64f..71ad1dca8a3e 100644 --- a/drivers/pikpak/util.go +++ b/drivers/pikpak/util.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/hex" "errors" + "github.com/alist-org/alist/v3/pkg/utils" "io" "net/http" @@ -141,7 +142,7 @@ func getGcid(r io.Reader, size int64) (string, error) { readSize := calcBlockSize(size) for { hash2.Reset() - if n, err := io.CopyN(hash2, r, readSize); err != nil && n == 0 { + if n, err := utils.CopyWithBufferN(hash2, r, readSize); err != nil && n == 0 { if err != io.EOF { return "", err } diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go index 291189ce088d..8674fbab26fe 100644 --- a/drivers/quark_uc/driver.go +++ b/drivers/quark_uc/driver.go @@ -143,7 +143,7 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File _ = tempFile.Close() }() m := md5.New() - _, err = io.Copy(m, tempFile) + _, err = utils.CopyWithBuffer(m, tempFile) if err != nil { return err } @@ -153,7 +153,7 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File } md5Str := hex.EncodeToString(m.Sum(nil)) s := sha1.New() - _, err = io.Copy(s, tempFile) + _, err = utils.CopyWithBuffer(s, tempFile) if err != nil { return err } diff --git a/drivers/smb/util.go b/drivers/smb/util.go index f4605536da7f..d9fbf6c5a5a3 100644 --- a/drivers/smb/util.go +++ b/drivers/smb/util.go @@ -1,7 +1,7 @@ package smb import ( - "io" + "github.com/alist-org/alist/v3/pkg/utils" "io/fs" "net" "os" @@ -74,7 +74,7 @@ func (d *SMB) CopyFile(src, dst string) error { } defer dstfd.Close() - if _, err = io.Copy(dstfd, srcfd); err != nil { + if _, err = utils.CopyWithBuffer(dstfd, srcfd); err != nil { return err } if srcinfo, err = d.fs.Stat(src); err != nil { diff --git a/drivers/thunder/util.go b/drivers/thunder/util.go index f6dec3260cf2..3ec8db58ffeb 100644 --- a/drivers/thunder/util.go +++ b/drivers/thunder/util.go @@ -190,7 +190,7 @@ func getGcid(r io.Reader, size int64) (string, error) { readSize := calcBlockSize(size) for { hash2.Reset() - if n, err := io.CopyN(hash2, r, readSize); err != nil && n == 0 { + if n, err := utils.CopyWithBufferN(hash2, r, readSize); err != nil && n == 0 { if err != io.EOF { return "", err } diff --git a/internal/net/request.go b/internal/net/request.go index 71f45aa7afc7..088ff66ab4ff 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/alist-org/alist/v3/pkg/utils" "io" "math" "net/http" @@ -271,7 +272,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int } } - n, err := io.Copy(ch.buf, resp.Body) + n, err := utils.CopyWithBuffer(ch.buf, resp.Body) if err != nil { return n, &errReadingBody{err: err} diff --git a/internal/net/serve.go b/internal/net/serve.go index a05667807593..adee75ae1d6c 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -162,7 +162,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time pw.CloseWithError(err) return } - if _, err := io.CopyN(part, reader, ra.Length); err != nil { + if _, err := utils.CopyWithBufferN(part, reader, ra.Length); err != nil { pw.CloseWithError(err) return } @@ -182,7 +182,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time w.WriteHeader(code) if r.Method != "HEAD" { - written, err := io.CopyN(w, sendContent, sendSize) + written, err := utils.CopyWithBufferN(w, sendContent, sendSize) if err != nil { log.Warnf("ServeHttp error. err: %s ", err) if written != sendSize { diff --git a/internal/net/util.go b/internal/net/util.go index 4347e2c404df..442018594874 100644 --- a/internal/net/util.go +++ b/internal/net/util.go @@ -2,6 +2,7 @@ package net import ( "fmt" + "github.com/alist-org/alist/v3/pkg/utils" "io" "math" "mime/multipart" @@ -330,7 +331,7 @@ func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.Rea log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected") } - if _, err := io.Copy(io.Discard, io.LimitReader(readCloser, offset)); err != nil { + if _, err := utils.CopyWithBuffer(io.Discard, io.LimitReader(readCloser, offset)); err != nil { return nil, err } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 4b882c519e09..40482f45a36c 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -104,7 +104,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil { bufSize := utils.Min(httpRange.Length, f.GetSize()) newBuf := bytes.NewBuffer(make([]byte, 0, bufSize)) - n, err := io.CopyN(newBuf, f.Reader, bufSize) + n, err := utils.CopyWithBufferN(newBuf, f.Reader, bufSize) if err != nil { return nil, err } diff --git a/pkg/gowebdav/client.go b/pkg/gowebdav/client.go index 2fca0b7f43db..cef501b9a152 100644 --- a/pkg/gowebdav/client.go +++ b/pkg/gowebdav/client.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/xml" "fmt" + "github.com/alist-org/alist/v3/pkg/utils" "io" "net/http" "net/url" @@ -419,7 +420,7 @@ func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadClos // stream in rs.Body if rs.StatusCode == 200 { // discard first 'offset' bytes. - if _, err := io.Copy(io.Discard, io.LimitReader(rs.Body, offset)); err != nil { + if _, err := utils.CopyWithBuffer(io.Discard, io.LimitReader(rs.Body, offset)); err != nil { return nil, newPathErrorErr("ReadStreamRange", path, err) } diff --git a/pkg/utils/file.go b/pkg/utils/file.go index 7ae07158998a..54247636dcbd 100644 --- a/pkg/utils/file.go +++ b/pkg/utils/file.go @@ -32,7 +32,7 @@ func CopyFile(src, dst string) error { } defer dstfd.Close() - if _, err = io.Copy(dstfd, srcfd); err != nil { + if _, err = CopyWithBuffer(dstfd, srcfd); err != nil { return err } if srcinfo, err = os.Stat(src); err != nil { @@ -121,7 +121,7 @@ func CreateTempFile(r io.Reader, size int64) (*os.File, error) { if err != nil { return nil, err } - readBytes, err := io.Copy(f, r) + readBytes, err := CopyWithBuffer(f, r) if err != nil { _ = os.Remove(f.Name()) return nil, errs.NewErr(err, "CreateTempFile failed") diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go index 8f8aaa26781e..fa06bcc24c2c 100644 --- a/pkg/utils/hash.go +++ b/pkg/utils/hash.go @@ -96,7 +96,7 @@ func HashData(hashType *HashType, data []byte, params ...any) string { // HashReader get hash of one hashType from a reader func HashReader(hashType *HashType, reader io.Reader, params ...any) (string, error) { h := hashType.NewFunc(params...) - _, err := io.Copy(h, reader) + _, err := CopyWithBuffer(h, reader) if err != nil { return "", errs.NewErr(err, "HashReader error") } diff --git a/pkg/utils/hash_test.go b/pkg/utils/hash_test.go index 55713c1afb34..0f5a2a3b14e6 100644 --- a/pkg/utils/hash_test.go +++ b/pkg/utils/hash_test.go @@ -4,7 +4,6 @@ import ( "bytes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "io" "testing" ) @@ -36,7 +35,7 @@ var hashTestSet = []hashTest{ func TestMultiHasher(t *testing.T) { for _, test := range hashTestSet { mh := NewMultiHasher([]*HashType{MD5, SHA1, SHA256}) - n, err := io.Copy(mh, bytes.NewBuffer(test.input)) + n, err := CopyWithBuffer(mh, bytes.NewBuffer(test.input)) require.NoError(t, err) assert.Len(t, test.input, int(n)) hashInfo := mh.GetHashInfo() diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 6852e28a83da..7be989c3fd78 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "sync" "time" "golang.org/x/exp/constraints" @@ -29,7 +30,7 @@ func CopyWithCtx(ctx context.Context, out io.Writer, in io.Reader, size int64, p // possible in the call process. var finish int64 = 0 s := size / 100 - _, err := io.Copy(out, readerFunc(func(p []byte) (int, error) { + _, err := CopyWithBuffer(out, readerFunc(func(p []byte) (int, error) { // golang non-blocking channel: https://gobyexample.com/non-blocking-channel-operations select { // if context has been canceled @@ -204,3 +205,31 @@ func Max[T constraints.Ordered](a, b T) T { } return a } + +var IoBuffPool = &sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024*2) // Two times of size in io package + }, +} + +func CopyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buff := IoBuffPool.Get().([]byte) + defer IoBuffPool.Put(buff) + written, err = io.CopyBuffer(dst, src, buff) + if err != nil { + return + } + return written, nil +} + +func CopyWithBufferN(dst io.Writer, src io.Reader, n int64) (written int64, err error) { + written, err = CopyWithBuffer(dst, io.LimitReader(src, n)) + if written == n { + return n, nil + } + if written < n && err == nil { + // src stopped early; must have been EOF. + err = io.EOF + } + return +}