diff --git a/conf/conf.go b/conf/conf.go index 7dc3851..a544d4a 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -30,6 +30,7 @@ const ( BaiduOpenApiDomain = "https://openapi.baidu.com" OpenApiDomain = "https://pan.baidu.com" PcsDataDomain = "https://d.pcs.baidu.com" + PcsApiDomain = "https://pcs.baidu.com" ) // 测试参数 diff --git a/examples/file_download.go b/examples/file_download.go index 5f9a5cc..1bac5d3 100644 --- a/examples/file_download.go +++ b/examples/file_download.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/jsyzchen/pan/conf" "github.com/jsyzchen/pan/file" ) @@ -23,8 +24,17 @@ func main() { fsID = 759719327699432 fileDownloader = file.NewDownloaderWithFsID(accessToken, fsID, localFilePath) if err := fileDownloader.Download(); err != nil { - fmt.Println("2.fileDownloader.Download failed, err:", err) + fmt.Println("2.fileDownloader.DownloadWithFsID failed, err:", err) return } fmt.Println("2.fileDownloader.Download success") + + // 方式3:通过文件路径下载,非开放平台公开接口,生产环境谨慎使用 + fileDownloader = file.NewDownloaderWithPath(conf.TestData.AccessToken, conf.TestData.Path, conf.TestData.LocalFilePath) + err := fileDownloader.Download() + if err != nil { + fmt.Println("3.fileDownloader.DownloaderWithPath failed, err:", err) + return + } + fmt.Println("3.fileDownloader.DownloaderWithPath success") } \ No newline at end of file diff --git a/file/download.go b/file/download.go index 4bb8cd1..aa93c5a 100644 --- a/file/download.go +++ b/file/download.go @@ -2,8 +2,11 @@ package file import ( "errors" + "github.com/jsyzchen/pan/account" + "github.com/jsyzchen/pan/conf" "github.com/jsyzchen/pan/utils/file" "log" + "net/url" ) type Downloader struct { @@ -15,7 +18,11 @@ type Downloader struct { TotalPart int } -func NewDownloader(accessToken string, downloadLink string, localFilePath string, ) *Downloader { +const ( + PcsFileDownloadUri = "/rest/2.0/pcs/file?method=download" +) + +func NewDownloader(accessToken string, downloadLink string, localFilePath string) *Downloader { return &Downloader{ AccessToken: accessToken, LocalFilePath: localFilePath, @@ -31,13 +38,14 @@ func NewDownloaderWithFsID(accessToken string, fsID uint64, localFilePath string } } -//func NewDownloaderWithPath(accessToken string, path string, localFilePath string) *Downloader { -// return &Downloader{ -// AccessToken: accessToken, -// Path: path, -// LocalFilePath: localFilePath, -// } -//} +// 非开放平台公开接口,生产环境谨慎使用 +func NewDownloaderWithPath(accessToken string, path string, localFilePath string) *Downloader { + return &Downloader{ + AccessToken: accessToken, + Path: path, + LocalFilePath: localFilePath, + } +} // 执行下载 func (d *Downloader) Download() error { @@ -61,8 +69,12 @@ func (d *Downloader) Download() error { return errors.New("file don't exist") } downloadLink = metas.List[0].DLink - } else if d.Path != "" { - + } else if d.Path != "" { // TODO 如何通过文件路径获取下载地址 + v := url.Values{} + v.Add("path", d.Path) + v.Add("access_token", d.AccessToken) + body := v.Encode() + downloadLink = conf.PcsApiDomain + PcsFileDownloadUri + "&" + body } else { return errors.New("param error") } @@ -73,6 +85,16 @@ func (d *Downloader) Download() error { downloadLink += "&access_token=" + d.AccessToken downloader := file.NewFileDownloader(downloadLink, d.LocalFilePath) + + accountClient := account.NewAccountClient(d.AccessToken) + if userInfo, err := accountClient.UserInfo(); err == nil { + log.Println("VipType:", userInfo.VipType) + if userInfo.VipType == 2 { //当前用户是超级会员 + downloader.SetPartSize(52428800) //设置每分片下载文件大小,50M + downloader.SetCoroutineNum(10) //分片下载并发数,普通用户不支持并发分片下载 + } + } + if err := downloader.Download(); err != nil { log.Println("download failed, err:", err) return err diff --git a/file/download_test.go b/file/download_test.go index acff2da..4c1a572 100644 --- a/file/download_test.go +++ b/file/download_test.go @@ -15,3 +15,15 @@ func TestDownload(t *testing.T) { } } +func TestDownloaderWithPath(t *testing.T) { + fileDownloader := NewDownloaderWithPath(conf.TestData.AccessToken, conf.TestData.Path, conf.TestData.LocalFilePath) + err := fileDownloader.Download() + if err != nil { + t.Fail() + } else { + t.Logf("TestDownload Success") + } +} + + + diff --git a/utils/file/download.go b/utils/file/download.go index 6ab9055..1f1eb1e 100644 --- a/utils/file/download.go +++ b/utils/file/download.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path" + "path/filepath" "strconv" "sync" "time" @@ -23,6 +24,7 @@ type Downloader struct { TotalPart int //下载线程 DoneFilePart []Part PartSize int + PartCoroutineNum int //分片下载协程数 } //filePart 文件分片 @@ -40,6 +42,8 @@ func NewFileDownloader(downloadLink, filePath string) *Downloader { FileSize: 0, Link: downloadLink, FilePath: filePath, + PartSize: 10485760,// 10M + PartCoroutineNum: 1, } } @@ -47,8 +51,12 @@ func (d *Downloader) SetTotalPart(totalPart int) { d.TotalPart = totalPart } -func (d *Downloader) SetPartSize(PartSize int) { - d.PartSize = PartSize +func (d *Downloader) SetPartSize(partSize int) { + d.PartSize = partSize +} + +func (d *Downloader) SetCoroutineNum(partCoroutineNum int) { + d.PartCoroutineNum = partCoroutineNum } //Run 开始下载任务 @@ -68,6 +76,8 @@ func (d *Downloader) Download() error { d.PartSize = 10485760 // 10M } + log.Println("fileTotalSize:", fileTotalSize) + if isSupportRange == false || fileTotalSize <= d.PartSize {//不支持Range下载或者文件比较小,直接下载文件 err := d.downloadWhole() return err @@ -89,6 +99,8 @@ func (d *Downloader) Download() error { jobs := make([]Part, d.TotalPart) eachSize := fileTotalSize / d.TotalPart + log.Println("eachSize:", eachSize) + for i := range jobs { jobs[i].Index = i if i == 0 { @@ -109,7 +121,11 @@ func (d *Downloader) Download() error { var wg sync.WaitGroup isFailed := false - sem := make(chan int, 10) //限制并发数,以防大文件下载导致占用服务器大量网络宽带和磁盘io + partCoroutineNum := d.PartCoroutineNum + if len(jobs) < partCoroutineNum { + partCoroutineNum = len(jobs) + } + sem := make(chan int, partCoroutineNum) //限制并发数,以防大文件下载导致占用服务器大量网络宽带和磁盘io for _, job := range jobs { wg.Add(1) sem <- 1 //当通道已满的时候将被阻塞 @@ -173,19 +189,22 @@ func (d *Downloader) downloadPart(c Part) error { if err != nil { return err } - if resp.StatusCode > 299 { - return errors.New(fmt.Sprintf("服务器错误状态码: %v", resp.StatusCode)) - } defer resp.Body.Close() bs, err := ioutil.ReadAll(resp.Body) + if resp.StatusCode > 299 { + log.Println(fmt.Sprintf("服务器错误,状态码: %v, msg:%s", resp.StatusCode, string(bs))) + return errors.New(fmt.Sprintf("服务器错误,状态码: %v, msg:%s", resp.StatusCode, string(bs))) + } + if err != nil { if err != io.EOF && err != io.ErrUnexpectedEOF {//unexpected EOF 处理 log.Println("ioutil.ReadAll error :", err) return err } } + if len(bs) != (c.To - c.From + 1) { - return errors.New("下载文件分片长度错误") + return errors.New(fmt.Sprintf("下载文件分片长度错误, len bs:%d", len(bs))) } //c.Data = bs @@ -194,11 +213,15 @@ func (d *Downloader) downloadPart(c Part) error { fileNamePrefix := fileName[0:len(path.Base(d.FilePath)) - len(path.Ext(d.FilePath))] nowTime := time.Now().UnixNano() / 1e6 partFilePath := path.Join(os.TempDir(), fileNamePrefix + "_" + strconv.Itoa(c.Index) + "_" + strconv.FormatInt(nowTime, 10)) + + log.Printf("partFilePath[%d]:%s", c.Index, partFilePath) + f, err := os.Create(partFilePath) if err != nil { log.Println("open file error :", err) return err } + // 关闭文件 defer f.Close() // 字节方式写入 @@ -207,6 +230,7 @@ func (d *Downloader) downloadPart(c Part) error { log.Println(err) return err } + c.FilePath = partFilePath d.DoneFilePart[c.Index] = c @@ -218,6 +242,21 @@ func (d *Downloader) downloadPart(c Part) error { //mergeFileParts 合并下载的文件 func (d *Downloader) mergeFileParts() error { log.Println("开始合并文件") + + //存储文件夹不存在的话先创建文件夹 + fileDir := filepath.Dir(d.FilePath) + _, err := os.Stat(fileDir) + if err != nil { + if os.IsNotExist(err){ + //递归创建文件夹 + err := os.MkdirAll(fileDir, os.ModePerm) + if err != nil{ + log.Println("MkdirAll failed:", err) + return err + } + } + } + mergedFile, err := os.Create(d.FilePath) if err != nil { return err