Skip to content

Commit

Permalink
Add caching support for remote directories
Browse files Browse the repository at this point in the history
  • Loading branch information
pbitty committed Aug 21, 2024
1 parent 5a504e3 commit 4175f72
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 56 deletions.
68 changes: 39 additions & 29 deletions task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,8 @@ func TestIncludesMultiLevel(t *testing.T) {
func TestIncludesRemote(t *testing.T) {
dir := "testdata/includes_remote"

os.RemoveAll(filepath.Join(dir, ".task"))

srv := httptest.NewServer(http.FileServer(http.Dir(dir)))
defer srv.Close()

Expand Down Expand Up @@ -1121,40 +1123,48 @@ func TestIncludesRemote(t *testing.T) {

var buff SyncBuffer

executors := []*task.Executor{
executors := []struct {
name string
executor *task.Executor
}{

{
Dir: dir,
Stdout: &buff,
Stderr: &buff,
Timeout: time.Minute,
Insecure: true,
Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true},

// Without caching
AssumeYes: true,
Download: true,
name: "online, no cache",
executor: &task.Executor{
Dir: dir,
Stdout: &buff,
Stderr: &buff,
Timeout: time.Minute,
Insecure: true,
Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true},

// Without caching
AssumeYes: true,
Download: true,
},
},

// Disabled until caching support for directories is available
//
// {
// Dir: dir,
// Stdout: &buff,
// Stderr: &buff,
// Timeout: time.Minute,
// Insecure: true,
// Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true},

// // With caching
// AssumeYes: false,
// Download: false,
// Offline: true,
// },
{
name: "offline, use-cache",
executor: &task.Executor{
Dir: dir,
Stdout: &buff,
Stderr: &buff,
Timeout: time.Minute,
Insecure: true,
Logger: &logger.Logger{Stdout: &buff, Stderr: &buff, Verbose: true},

// With caching
AssumeYes: false,
Download: false,
Offline: true,
},
},
}

for j, e := range executors {
t.Run(fmt.Sprint(j), func(t *testing.T) {
require.NoError(t, e.Setup())
require.NoError(t, e.executor.Setup())

for k, task := range tasks {
t.Run(task, func(t *testing.T) {
Expand All @@ -1167,7 +1177,7 @@ func TestIncludesRemote(t *testing.T) {
path := filepath.Join(dir, outputFile)
require.NoError(t, os.RemoveAll(path))

require.NoError(t, e.Run(context.Background(), &ast.Call{Task: task}))
require.NoError(t, e.executor.Run(context.Background(), &ast.Call{Task: task}))

actualContent, err := os.ReadFile(path)
require.NoError(t, err)
Expand All @@ -1177,7 +1187,7 @@ func TestIncludesRemote(t *testing.T) {

for _, task := range tc.extraTasks {
t.Run(task, func(t *testing.T) {
require.NoError(t, e.Run(context.Background(), &ast.Call{Task: task}))
require.NoError(t, e.executor.Run(context.Background(), &ast.Call{Task: task}))
})
}
})
Expand Down
78 changes: 65 additions & 13 deletions taskfile/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,60 @@ func NewCache(dir string) (*Cache, error) {
func checksum(b []byte) string {
h := sha256.New()
h.Write(b)
return fmt.Sprintf("%x", h.Sum(nil))
return fmt.Sprintf("%x", h.Sum(nil))[:16]
}

func (c *Cache) write(node Node, b []byte) error {
return os.WriteFile(c.cacheFilePath(node), b, 0o644)
}
func (c *Cache) write(node Node, source *source) (*source, error) {
p := c.contentsPath(node)

fi, err := os.Stat(p)
switch {
case os.IsNotExist(err):
// Do nothing
case !fi.IsDir():
return nil, fmt.Errorf("error writing to contents path %s: not a directory", p)
default:
err := os.RemoveAll(p)
if err != nil {
return nil, fmt.Errorf("error clearing contents directory: %s", err)
}
}

if err := os.Rename(source.FileDirectory, p); err != nil {
return nil, err
}

func (c *Cache) read(node Node) ([]byte, error) {
return os.ReadFile(c.cacheFilePath(node))
if err := os.WriteFile(c.checksumFilePath(node), []byte(checksum(source.FileContent)), 0o644); err != nil {
return nil, err
}

if err := os.WriteFile(c.taskfileNamePath(node), []byte(source.Filename), 0o644); err != nil {
return nil, err
}

return c.read(node)
}

func (c *Cache) writeChecksum(node Node, checksum string) error {
return os.WriteFile(c.checksumFilePath(node), []byte(checksum), 0o644)
// read loads the cached copy of node from disk.
//
// The Taskfile's name is read from the file at taskfileNamePath(node); the
// Taskfile itself is then read from that name inside contentsPath(node).
// Errors from either read are returned unchanged, so callers can still match
// them with errors.Is (for example against os.ErrNotExist).
func (c *Cache) read(node Node) (*source, error) {
	dir := c.contentsPath(node)

	// The stored name tells us which file inside the contents directory is
	// the Taskfile entrypoint.
	nameBytes, err := os.ReadFile(c.taskfileNamePath(node))
	if err != nil {
		return nil, err
	}
	name := string(nameBytes)

	data, err := os.ReadFile(filepath.Join(dir, name))
	if err != nil {
		return nil, err
	}

	cached := &source{
		FileContent:   data,
		FileDirectory: dir,
		Filename:      name,
	}
	return cached, nil
}

func (c *Cache) readChecksum(node Node) string {
Expand All @@ -49,22 +90,33 @@ func (c *Cache) key(node Node) string {
return strings.TrimRight(checksum([]byte(node.Location())), "=")
}

func (c *Cache) cacheFilePath(node Node) string {
return c.filePath(node, "yaml")
}

// checksumFilePath returns the path that filePath builds for node with the
// "checksum" suffix.
func (c *Cache) checksumFilePath(node Node) string {
	const suffix = "checksum"
	return c.filePath(node, suffix)
}

// contentsPath returns the path that filePath builds for node with the
// "contents" suffix.
func (c *Cache) contentsPath(node Node) string {
	const suffix = "contents"
	return c.filePath(node, suffix)
}

// taskfileNamePath returns the path that filePath builds for node with the
// "taskfileName" suffix.
func (c *Cache) taskfileNamePath(node Node) string {
	const suffix = "taskfileName"
	return c.filePath(node, suffix)
}

func (c *Cache) filePath(node Node, suffix string) string {
lastDir, filename := node.FilenameAndLastDir()
prefix := filename
// Means it's not "", nor "." nor "/", so it's a valid directory
if len(lastDir) > 1 {
prefix = fmt.Sprintf("%s-%s", lastDir, filename)
}
return filepath.Join(c.dir, fmt.Sprintf("%s.%s.%s", prefix, c.key(node), suffix))

dir := filepath.Join(c.dir, fmt.Sprintf("%s.%s", prefix, c.key(node)))
if err := os.MkdirAll(dir, 0o755); err != nil {
// TODO proper error-handling
panic("error creating directory: " + err.Error())
}

return filepath.Join(dir, suffix)
}

func (c *Cache) Clear() error {
Expand Down
4 changes: 2 additions & 2 deletions taskfile/node_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func NewRemoteNode(
}

func (r *RemoteNode) Location() string {
return r.url.String()
return r.proto + "::" + r.url.String()
}

func (r *RemoteNode) Remote() bool {
Expand Down Expand Up @@ -155,7 +155,7 @@ func (r *RemoteNode) loadSource(ctx context.Context) (*source, error) {
return nil, err
}
r.client.Ctx = ctx
r.client.Src = r.proto + "::" + r.url.String()
r.client.Src = r.Location()
r.client.Dst = dir

if err := r.client.Get(); err != nil {
Expand Down
25 changes: 13 additions & 12 deletions taskfile/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ func (r *Reader) include(node Node) error {
}

func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
var b []byte
var err error
var cache *Cache
source := &source{}
Expand All @@ -194,7 +193,7 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {

// If the file is remote and we're in offline mode, check if we have a cached copy
if node.Remote() && r.offline {
if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
if source, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()}
} else if err != nil {
return nil, err
Expand All @@ -215,7 +214,7 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout}
}
// Search for any cached copies
if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
if source, err = cache.read(node); errors.Is(err, os.ErrNotExist) {
return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true}
} else if err != nil {
return nil, err
Expand All @@ -225,15 +224,14 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {
return nil, err
} else {
downloaded = true
b = source.FileContent
}

// If the node was remote, we need to check the checksum
if node.Remote() && downloaded {
r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location())

// Get the checksums
checksum := checksum(b)
checksum := checksum(source.FileContent)
cachedChecksum := cache.readChecksum(node)

var prompt string
Expand All @@ -252,25 +250,28 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) {

// If the hash has changed (or is new)
if checksum != cachedChecksum {
// Store the checksum
if err := cache.writeChecksum(node, checksum); err != nil {
return nil, err
}
// Cache the file
r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location())
if err = cache.write(node, b); err != nil {
if source, err = cache.write(node, source); err != nil {
return nil, err
}
}
}
}

// TODO: Find a cleaner way to override source when loading from the cache
// Without this, a Node's source gets moved, but later usages of ResolveEntrypoint
// will be relative to the old source location - before it got moved into the cache.
if n, ok := node.(*RemoteNode); ok {
n.cachedSource = source
}

var tf ast.Taskfile
if err := yaml.Unmarshal(b, &tf); err != nil {
if err := yaml.Unmarshal(source.FileContent, &tf); err != nil {
// Decode the taskfile and add the file info to any errors
taskfileInvalidErr := &errors.TaskfileDecodeError{}
if errors.As(err, &taskfileInvalidErr) {
return nil, taskfileInvalidErr.WithFileInfo(node.Location(), b, 2)
return nil, taskfileInvalidErr.WithFileInfo(node.Location(), source.FileContent, 2)
}
return nil, &errors.TaskfileInvalidError{URI: filepathext.TryAbsToRel(node.Location()), Err: err}
}
Expand Down

0 comments on commit 4175f72

Please sign in to comment.