Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 81 additions & 36 deletions feature/s3/transfermanager/api_op_DownloadDirectory.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
Expand All @@ -19,18 +20,14 @@ import (
// DownloadDirectoryInput represents a request to the DownloadDirectory() call
type DownloadDirectoryInput struct {
// Bucket where objects are downloaded from
Bucket string
Bucket *string

// The destination directory to download to
Destination string
Destination *string

// The S3 key prefix to use for listing objects. If not provided,
// all objects under a bucket will be retrieved
KeyPrefix string

// The s3 delimiter used to convert keyname to local filepath if it
// is different from local file separator
S3Delimiter string
KeyPrefix *string

// A callback func to allow users to filter out unwanted objects
// according to the bool returned from the function
Expand All @@ -39,6 +36,12 @@ type DownloadDirectoryInput struct {
// A callback function to allow customers to update individual
// GetObjectInput that the S3 Transfer Manager generates
Callback GetRequestCallback

// A callback function to allow users to control the download behavior
// when an object fails to download. The directory download is terminated
// if the function returns a non-nil error; if it returns nil, the failed
// object is skipped and the download continues
FailurePolicy DownloadDirectoryFailurePolicy
}

// ObjectFilter is the callback to allow users to filter out unwanted objects.
Expand All @@ -56,10 +59,41 @@ type GetRequestCallback interface {
UpdateRequest(*GetObjectInput)
}

// DownloadDirectoryFailurePolicy is a callback that lets users control the
// download behavior when an object fails to download. It is invoked for every
// failed object. If OnDownloadFailed returns a non-nil error, the downloader
// cancels all ongoing single-object download requests and terminates the
// directory download. If it returns nil, the downloader counts the current
// request as a failed object and continues getting the other objects.
type DownloadDirectoryFailurePolicy interface {
// OnDownloadFailed receives the directory-level input, the GetObjectInput of
// the object that failed, and the error that caused the failure.
OnDownloadFailed(*DownloadDirectoryInput, *GetObjectInput, error) error
}

// TerminateDownloadPolicy implements DownloadDirectoryFailurePolicy by
// propagating the failure, which cancels all other ongoing object downloads
// and terminates the download directory call. It is the policy used when
// the input's FailurePolicy is nil.
type TerminateDownloadPolicy struct{}

// OnDownloadFailed returns the original err unchanged, signalling the
// downloader to stop. directoryInput and objectInput are not inspected.
func (TerminateDownloadPolicy) OnDownloadFailed(directoryInput *DownloadDirectoryInput, objectInput *GetObjectInput, err error) error {
return err
}

// IgnoreDownloadFailurePolicy implements DownloadDirectoryFailurePolicy by
// ignoring single-object download errors so that the remaining objects
// continue to be downloaded.
type IgnoreDownloadFailurePolicy struct{}

// OnDownloadFailed ignores the input error and always returns nil, so the
// object is counted as failed and the directory download continues.
func (IgnoreDownloadFailurePolicy) OnDownloadFailed(*DownloadDirectoryInput, *GetObjectInput, error) error {
return nil
}

// DownloadDirectoryOutput represents a response from the DownloadDirectory() call
type DownloadDirectoryOutput struct {
// Total number of objects successfully downloaded
ObjectsDownloaded int
ObjectsDownloaded int64

// Total number of objects failed to download
ObjectsFailed int64
}

type objectEntry struct {
Expand All @@ -75,13 +109,13 @@ type objectEntry struct {
// download. These options are copies of the original Options instance, the client of which DownloadDirectory is called from.
// Modifying the options will not impact the original Client and Options instance.
func (c *Client) DownloadDirectory(ctx context.Context, input *DownloadDirectoryInput, opts ...func(*Options)) (*DownloadDirectoryOutput, error) {
fileInfo, err := os.Stat(input.Destination)
fileInfo, err := os.Stat(aws.ToString(input.Destination))
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("error when getting destination folder info: %v", err)
}
} else if !fileInfo.IsDir() {
return nil, fmt.Errorf("the destination path %s doesn't point to a valid directory", input.Destination)
return nil, fmt.Errorf("the destination path %s doesn't point to a valid directory", aws.ToString(input.Destination))

}

Expand All @@ -94,11 +128,13 @@ func (c *Client) DownloadDirectory(ctx context.Context, input *DownloadDirectory
}

type directoryDownloader struct {
c *Client
options Options
in *DownloadDirectoryInput
c *Client
options Options
in *DownloadDirectoryInput
failurePolicy DownloadDirectoryFailurePolicy

objectsDownloaded int
objectsDownloaded int64
objectsFailed int64

err error

Expand All @@ -125,8 +161,8 @@ func (d *directoryDownloader) downloadDirectory(ctx context.Context) (*DownloadD
break
}
listOutput, err := d.options.S3.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(d.in.Bucket),
Prefix: nzstring(d.in.KeyPrefix),
Bucket: d.in.Bucket,
Prefix: d.in.KeyPrefix,
ContinuationToken: nzstring(continuationToken),
})
if err != nil {
Expand All @@ -139,7 +175,7 @@ func (d *directoryDownloader) downloadDirectory(ctx context.Context) (*DownloadD
break
}
key := aws.ToString(o.Key)
if strings.HasSuffix(key, "/") || strings.HasSuffix(key, d.in.S3Delimiter) {
if strings.HasSuffix(key, "/") {
continue // skip folder object
}
if d.in.Filter != nil && !d.in.Filter.FilterObject(o) {
Expand Down Expand Up @@ -167,6 +203,7 @@ func (d *directoryDownloader) downloadDirectory(ctx context.Context) (*DownloadD

out := &DownloadDirectoryOutput{
ObjectsDownloaded: d.objectsDownloaded,
ObjectsFailed: d.objectsFailed,
}

d.emitter.Complete(ctx, out)
Expand All @@ -175,26 +212,30 @@ func (d *directoryDownloader) downloadDirectory(ctx context.Context) (*DownloadD
}

func (d *directoryDownloader) init() {
if d.in.S3Delimiter == "" {
d.in.S3Delimiter = "/"
d.failurePolicy = TerminateDownloadPolicy{}
if d.in.FailurePolicy != nil {
d.failurePolicy = d.in.FailurePolicy
}

d.emitter = &directoryObjectsProgressEmitter{
Listeners: d.options.DirectoryProgressListeners,
}
}

func (d *directoryDownloader) getLocalPath(key string) (string, error) {
keyprefix := d.in.KeyPrefix
if keyprefix != "" && !strings.HasSuffix(keyprefix, d.in.S3Delimiter) {
keyprefix = keyprefix + d.in.S3Delimiter
keyprefix := aws.ToString(d.in.KeyPrefix)
delimiter := "/"
destination := aws.ToString(d.in.Destination)
if keyprefix != "" && !strings.HasSuffix(keyprefix, delimiter) {
keyprefix = keyprefix + delimiter
}
path := filepath.Join(d.in.Destination, strings.ReplaceAll(strings.TrimPrefix(key, keyprefix), d.in.S3Delimiter, string(os.PathSeparator)))
relPath, err := filepath.Rel(d.in.Destination, path)
path := filepath.Join(destination, strings.ReplaceAll(strings.TrimPrefix(key, keyprefix), delimiter, string(os.PathSeparator)))
relPath, err := filepath.Rel(destination, path)
if err != nil {
return "", err
}
if relPath == "." || strings.Contains(relPath, "..") {
return "", fmt.Errorf("resolved local path %s is outside of destination %s", path, d.in.Destination)
return "", fmt.Errorf("resolved local path %s is outside of destination %s", path, destination)
}

return path, nil
Expand All @@ -221,14 +262,19 @@ func (d *directoryDownloader) downloadObject(ctx context.Context, ch chan object

input := &GetObjectInput{
Bucket: d.in.Bucket,
Key: data.key,
Key: aws.String(data.key),
}
if d.in.Callback != nil {
d.in.Callback.UpdateRequest(input)
}
out, err := d.c.GetObject(ctx, input)
if err != nil {
d.setErr(fmt.Errorf("error when downloading object %s: %v", data.key, err))
err = d.failurePolicy.OnDownloadFailed(d.in, input, err)
if err != nil {
d.setErr(fmt.Errorf("error when heading info of object %s: %v", data.key, err))
} else {
atomic.AddInt64(&d.objectsFailed, 1)
}
continue
}

Expand All @@ -248,23 +294,22 @@ func (d *directoryDownloader) downloadObject(ctx context.Context, ch chan object
}
n, err := io.Copy(file, out.Body)
if err != nil {
d.setErr(fmt.Errorf("error when writing to local file %s: %v", data.path, err))
// this copy is where s3.GetObject is really called, so a failure here must also be handled by the failure policy
err = d.failurePolicy.OnDownloadFailed(d.in, input, err)
if err != nil {
d.setErr(fmt.Errorf("error when getting object and writing to local file %s: %v", data.path, err))
} else {
atomic.AddInt64(&d.objectsFailed, 1)
}
os.Remove(data.path)
continue
}

d.incrObjectsDownloaded(1)
atomic.AddInt64(&d.objectsDownloaded, 1)
d.emitter.ObjectsTransferred(ctx, n)
}
}

func (d *directoryDownloader) incrObjectsDownloaded(n int) {
d.mu.Lock()
defer d.mu.Unlock()

d.objectsDownloaded += n
}

func (d *directoryDownloader) setErr(err error) {
d.mu.Lock()
defer d.mu.Unlock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,6 @@ func TestInteg_DownloadDirectory(t *testing.T) {
ExpectObjectsDownloaded: 3,
ExpectFiles: []string{"bar", "oiibaz/zoo", "baz/zoo"},
},
"multi file with prefix and custom delimiter": {
ObjectsSize: map[string]int64{
"yee#bar": 2 * 1024 * 1024,
"yee#baz#": 0,
"yee#baz#zoo": 10 * 1024 * 1024,
"yee#oii@zoo": 10 * 1024 * 1024,
"yee#yee#..#bla": 2 * 1024 * 1024,
"ye": 20 * 1024 * 1024,
},
KeyPrefix: "yee#",
Delimiter: "#",
ExpectObjectsDownloaded: 4,
ExpectFiles: []string{"bar", "baz/zoo", "oii@zoo", "bla"},
},
}

for name, c := range cases {
Expand Down
Loading
Loading