diff --git a/.stats.yml b/.stats.yml index dbab236..377ee95 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 24 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fhypeman-8fded10e90df28c07b64a92d12d665d54749b9fc13c35520667637fc596957d9.yml -openapi_spec_hash: 7374a732372bddf7f2c0b532b56ae3fb -config_hash: 510018ffa6ad6a17875954f66fe69598 +configured_endpoints: 30 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fhypeman-28e78b73c796f9ee866671ed946402b5d569e683c3207d57c9143eb7d6f83fb6.yml +openapi_spec_hash: fce0ac8713369a5f048bac684ed34fc8 +config_hash: f65a6a2bcef49a9f623212f9de6d6f6f diff --git a/api.md b/api.md index 11cb961..cb5f698 100644 --- a/api.md +++ b/api.md @@ -30,6 +30,7 @@ Params Types: Response Types: - hypeman.Instance +- hypeman.PathInfo - hypeman.VolumeMount Methods: @@ -42,6 +43,7 @@ Methods: - client.Instances.Restore(ctx context.Context, id string) (hypeman.Instance, error) - client.Instances.Standby(ctx context.Context, id string) (hypeman.Instance, error) - client.Instances.Start(ctx context.Context, id string) (hypeman.Instance, error) +- client.Instances.Stat(ctx context.Context, id string, query hypeman.InstanceStatParams) (hypeman.PathInfo, error) - client.Instances.Stop(ctx context.Context, id string) (hypeman.Instance, error) ## Volumes diff --git a/go.mod b/go.mod index 52bc772..ca5e04e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.24.0 require ( github.com/google/go-containerregistry v0.20.7 + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 ) @@ -15,6 +17,7 @@ require ( github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/stargz-snapshotter/estargz v0.18.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/cli v29.0.3+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect @@ -32,6 +35,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -49,4 +53,5 @@ require ( golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect google.golang.org/protobuf v1.36.10 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0e49963..fa2fd63 100644 --- a/go.sum +++ b/go.sum @@ -42,10 +42,16 @@ github.com/google/go-containerregistry v0.20.7 h1:24VGNpS0IwrOZ2ms2P1QE3Xa5X9p4p github.com/google/go-containerregistry v0.20.7/go.mod h1:Lx5LCZQjLH1QBaMPeGwsME9biPeo1lPx6lbGj/UmzgM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -66,6 +72,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -125,6 +133,8 @@ google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/instance.go b/instance.go index 6d37852..d393460 100644 --- a/instance.go +++ b/instance.go @@ -144,6 +144,19 @@ func (r *InstanceService) Start(ctx context.Context, id string, opts ...option.R return } +// Returns information about a path in the guest filesystem. Useful for checking if +// a path exists, its type, and permissions before performing file operations. +func (r *InstanceService) Stat(ctx context.Context, id string, query InstanceStatParams, opts ...option.RequestOption) (res *PathInfo, err error) { + opts = slices.Concat(r.Options, opts) + if id == "" { + err = errors.New("missing required id parameter") + return + } + path := fmt.Sprintf("instances/%s/stat", id) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) + return +} + // Stop instance (graceful shutdown) func (r *InstanceService) Stop(ctx context.Context, id string, opts ...option.RequestOption) (res *Instance, err error) { opts = slices.Concat(r.Options, opts) @@ -277,6 +290,45 @@ func (r *InstanceNetwork) UnmarshalJSON(data []byte) error { return apijson.UnmarshalRoot(data, r) } +type PathInfo struct { + // Whether the path exists + Exists bool `json:"exists,required"` + // Error message if stat failed (e.g., permission denied). Only set when exists is + // false due to an error rather than the path not existing. + Error string `json:"error,nullable"` + // True if this is a directory + IsDir bool `json:"is_dir"` + // True if this is a regular file + IsFile bool `json:"is_file"` + // True if this is a symbolic link (only set when follow_links=false) + IsSymlink bool `json:"is_symlink"` + // Symlink target path (only set when is_symlink=true) + LinkTarget string `json:"link_target,nullable"` + // File mode (Unix permissions) + Mode int64 `json:"mode"` + // File size in bytes + Size int64 `json:"size"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + Exists respjson.Field + Error respjson.Field + IsDir respjson.Field + IsFile respjson.Field + IsSymlink respjson.Field + LinkTarget respjson.Field + Mode respjson.Field + Size respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// Returns the unmodified JSON received from the API +func (r PathInfo) RawJSON() string { return r.JSON.raw } +func (r *PathInfo) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + type VolumeMount struct { // Path where volume is mounted in the guest MountPath string `json:"mount_path,required"` @@ -354,6 +406,8 @@ type InstanceNewParams struct { Size param.Opt[string] `json:"size,omitzero"` // Number of virtual CPUs Vcpus param.Opt[int64] `json:"vcpus,omitzero"` + // Device IDs or names to attach for GPU/PCI passthrough + Devices []string `json:"devices,omitzero"` // Environment variables Env map[string]string `json:"env,omitzero"` // Network configuration for the instance @@ -422,3 +476,19 @@ const ( InstanceLogsParamsSourceVmm InstanceLogsParamsSource = "vmm" InstanceLogsParamsSourceHypeman InstanceLogsParamsSource = "hypeman" ) + +type InstanceStatParams struct { + // Path to stat in the guest filesystem + Path string `query:"path,required" json:"-"` + // Follow symbolic links (like stat vs lstat) + FollowLinks param.Opt[bool] `query:"follow_links,omitzero" json:"-"` + paramObj +} + +// URLQuery serializes [InstanceStatParams]'s query parameters as `url.Values`. +func (r InstanceStatParams) URLQuery() (v url.Values, err error) { + return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{ + ArrayFormat: apiquery.ArrayQueryFormatComma, + NestedFormat: apiquery.NestedQueryFormatBrackets, + }) +} diff --git a/instance_test.go b/instance_test.go index 4910848..e7642af 100644 --- a/instance_test.go +++ b/instance_test.go @@ -195,6 +195,36 @@ func TestInstanceStart(t *testing.T) { } } +func TestInstanceStatWithOptionalParams(t *testing.T) { + t.Skip("Prism tests are disabled") + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := hypeman.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + _, err := client.Instances.Stat( + context.TODO(), + "id", + hypeman.InstanceStatParams{ + Path: "path", + FollowLinks: hypeman.Bool(true), + }, + ) + if err != nil { + var apierr *hypeman.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} + func TestInstanceStop(t *testing.T) { t.Skip("Prism tests are disabled") baseURL := "http://localhost:4010" diff --git a/lib/cp.go b/lib/cp.go new file mode 100644 index 0000000..0d39038 --- /dev/null +++ b/lib/cp.go @@ -0,0 +1,703 @@ +// Package lib provides manually-maintained functionality that extends the auto-generated SDK. +package lib + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/fs" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/gorilla/websocket" + "github.com/onkernel/hypeman-go/internal/requestconfig" +) + +// CpConfig holds the configuration needed for copy operations. +// Extract this from a hypeman.Client using ExtractCpConfig. +type CpConfig struct { + // BaseURL is the base URL for the hypeman API + BaseURL string + // APIKey is the JWT token for authentication + APIKey string +} + +// ExtractCpConfig extracts the base URL and API key from client options. +func ExtractCpConfig(opts []requestconfig.RequestOption) (CpConfig, error) { + cfg := &requestconfig.RequestConfig{} + if err := cfg.Apply(opts...); err != nil { + return CpConfig{}, fmt.Errorf("apply options: %w", err) + } + + baseURL := cfg.BaseURL + if baseURL == nil { + baseURL = cfg.DefaultBaseURL + } + if baseURL == nil { + return CpConfig{}, fmt.Errorf("base URL not configured") + } + + return CpConfig{ + BaseURL: baseURL.String(), + APIKey: cfg.APIKey, + }, nil +} + +// CpCallbacks provides optional progress callbacks for copy operations. +type CpCallbacks struct { + OnFileStart func(path string, size int64) // Called when a file starts copying + OnProgress func(bytesCopied int64) // Called as bytes are copied + OnFileEnd func(path string) // Called when a file finishes copying +} + +// CpToInstanceOptions configures a copy-to-instance operation +type CpToInstanceOptions struct { + InstanceID string // Instance ID to copy to + SrcPath string // Local source path + DstPath string // Destination path in guest + Mode fs.FileMode // Optional: override file mode (0 = auto-detect) + Archive bool // Preserve UID/GID ownership + FollowLinks bool // Follow symbolic links when copying + Callbacks *CpCallbacks // Optional: progress callbacks + Dialer WsDialer // Optional: custom WebSocket dialer (for testing) +} + +// CpFromInstanceOptions configures a copy-from-instance operation +type CpFromInstanceOptions struct { + InstanceID string // Instance ID to copy from + SrcPath string // Source path in guest + DstPath string // Local destination path + FollowLinks bool // Follow symbolic links + Archive bool // Preserve UID/GID ownership + Callbacks *CpCallbacks // Optional: progress callbacks + Dialer WsDialer // Optional: custom WebSocket dialer (for testing) +} + +// cpRequest is the JSON request sent over WebSocket +type cpRequest struct { + Direction string `json:"direction"` + GuestPath string `json:"guest_path"` + IsDir bool `json:"is_dir,omitempty"` + Mode uint32 `json:"mode,omitempty"` + FollowLinks bool `json:"follow_links,omitempty"` + Uid uint32 `json:"uid,omitempty"` + Gid uint32 `json:"gid,omitempty"` +} + +// cpFileHeader is received from the server when copying from guest +type cpFileHeader struct { + Type string `json:"type"` + Path string `json:"path"` + Mode uint32 `json:"mode"` + IsDir bool `json:"is_dir"` + IsSymlink bool `json:"is_symlink"` + LinkTarget string `json:"link_target"` + Size int64 `json:"size"` + Mtime int64 `json:"mtime"` + Uid uint32 `json:"uid,omitempty"` + Gid uint32 `json:"gid,omitempty"` +} + +// cpEndMarker signals end of file or transfer +type cpEndMarker struct { + Type string `json:"type"` + Final bool `json:"final"` +} + +// cpResult is the response from a copy-to operation +type cpResult struct { + Type string `json:"type"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + BytesWritten int64 `json:"bytes_written,omitempty"` +} + +// cpError is an error message from the server +type cpError struct { + Type string `json:"type"` + Message string `json:"message"` + Path string `json:"path,omitempty"` +} + +// CpToInstance copies a file or directory to a running instance. +// +// Example: +// +// cfg, _ := lib.ExtractCpConfig(client.Options) +// err := lib.CpToInstance(ctx, cfg, lib.CpToInstanceOptions{ +// InstanceID: "inst_123", +// SrcPath: "./local-file.txt", +// DstPath: "/app/file.txt", +// }) +func CpToInstance(ctx context.Context, cfg CpConfig, opts CpToInstanceOptions) error { + return cpToInstanceInternal(ctx, cfg, opts, nil) +} + +// cpToInstanceInternal is the internal implementation that accepts visitedDirs for cycle detection +func cpToInstanceInternal(ctx context.Context, cfg CpConfig, opts CpToInstanceOptions, visitedDirs map[string]bool) error { + // Build WebSocket URL + wsURL, err := buildWsURL(cfg.BaseURL, opts.InstanceID) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + + // Connect to WebSocket + headers := http.Header{} + headers.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey)) + + // Use provided dialer or default + dialer := opts.Dialer + if dialer == nil { + dialer = &DefaultDialer{} + } + + ws, resp, err := dialer.DialContext(ctx, wsURL, headers) + if err != nil { + if resp != nil { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("websocket connect failed (HTTP %d): %s", resp.StatusCode, string(body)) + } + return fmt.Errorf("websocket connect failed: %w", err) + } + defer ws.Close() + + // Stat the source + srcInfo, err := os.Stat(opts.SrcPath) + if err != nil { + return fmt.Errorf("stat source: %w", err) + } + + mode := opts.Mode + if mode == 0 { + mode = srcInfo.Mode().Perm() + } + + // Get UID/GID if archive mode is enabled + var uid, gid uint32 + if opts.Archive { + if stat, ok := srcInfo.Sys().(*syscall.Stat_t); ok { + uid = stat.Uid + gid = stat.Gid + } + } + + // Send initial request + req := cpRequest{ + Direction: "to", + GuestPath: opts.DstPath, + IsDir: srcInfo.IsDir(), + Mode: uint32(mode), + FollowLinks: opts.FollowLinks, + Uid: uid, + Gid: gid, + } + reqJSON, _ := json.Marshal(req) + if err := ws.WriteMessage(websocket.TextMessage, reqJSON); err != nil { + return fmt.Errorf("send request: %w", err) + } + + if srcInfo.IsDir() { + // Track visited directories to detect symlink cycles + if visitedDirs == nil { + visitedDirs = make(map[string]bool) + } + absPath, _ := filepath.Abs(opts.SrcPath) + absPath, _ = filepath.EvalSymlinks(absPath) + if !visitedDirs[absPath] { + visitedDirs[absPath] = true + } + return copyDirToWs(ctx, cfg, ws, opts.SrcPath, opts.DstPath, opts.InstanceID, opts.Archive, opts.FollowLinks, opts.Dialer, opts.Callbacks, visitedDirs) + } + return copyFileToWs(ws, opts.SrcPath, srcInfo.Size(), opts.Callbacks) +} + +// copyFileToWs copies a single file to the WebSocket +func copyFileToWs(ws WsConn, srcPath string, size int64, callbacks *CpCallbacks) error { + file, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + defer file.Close() + + // Notify file start + if callbacks != nil && callbacks.OnFileStart != nil { + callbacks.OnFileStart(srcPath, size) + } + + buf := make([]byte, 32*1024) + var bytesSent int64 + for { + n, err := file.Read(buf) + if n > 0 { + if sendErr := ws.WriteMessage(websocket.BinaryMessage, buf[:n]); sendErr != nil { + return fmt.Errorf("send data: %w", sendErr) + } + bytesSent += int64(n) + // Notify progress + if callbacks != nil && callbacks.OnProgress != nil { + callbacks.OnProgress(bytesSent) + } + } + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("read source: %w", err) + } + } + + // Send end marker + endMsg, _ := json.Marshal(map[string]string{"type": "end"}) + if err := ws.WriteMessage(websocket.TextMessage, endMsg); err != nil { + return fmt.Errorf("send end: %w", err) + } + + // Wait for result + _, message, err := ws.ReadMessage() + if err != nil { + return fmt.Errorf("read result: %w", err) + } + + // Check message type first - server may send error or result + var msgType struct { + Type string `json:"type"` + } + if err := json.Unmarshal(message, &msgType); err != nil { + return fmt.Errorf("parse message type: %w", err) + } + + if msgType.Type == "error" { + var errMsg cpError + if err := json.Unmarshal(message, &errMsg); err != nil { + return fmt.Errorf("parse error: %w", err) + } + return fmt.Errorf("copy failed: %s", errMsg.Message) + } + + var result cpResult + if err := json.Unmarshal(message, &result); err != nil { + return fmt.Errorf("parse result: %w", err) + } + + if !result.Success { + return fmt.Errorf("copy failed: %s", result.Error) + } + + // Notify file end (srcPath not available here, use empty string) + if callbacks != nil && callbacks.OnFileEnd != nil { + callbacks.OnFileEnd(srcPath) + } + + return nil +} + +// copyDirToWs copies a directory to the WebSocket +func copyDirToWs(ctx context.Context, cfg CpConfig, ws WsConn, srcPath, dstPath, instanceID string, archive, followLinks bool, dialer WsDialer, callbacks *CpCallbacks, visitedDirs map[string]bool) error { + // For directory copy, we just send the end marker + // The server will create the directory + endMsg, _ := json.Marshal(map[string]string{"type": "end"}) + if err := ws.WriteMessage(websocket.TextMessage, endMsg); err != nil { + return fmt.Errorf("send end: %w", err) + } + + // Wait for result + _, message, err := ws.ReadMessage() + if err != nil { + return fmt.Errorf("read result: %w", err) + } + + // Check message type first - server may send error or result + var msgType struct { + Type string `json:"type"` + } + if err := json.Unmarshal(message, &msgType); err != nil { + return fmt.Errorf("parse message type: %w", err) + } + + if msgType.Type == "error" { + var errMsg cpError + if err := json.Unmarshal(message, &errMsg); err != nil { + return fmt.Errorf("parse error: %w", err) + } + return fmt.Errorf("copy failed: %s", errMsg.Message) + } + + var result cpResult + if err := json.Unmarshal(message, &result); err != nil { + return fmt.Errorf("parse result: %w", err) + } + + if !result.Success { + return fmt.Errorf("copy failed: %s", result.Error) + } + + // Now recursively copy contents + return filepath.WalkDir(srcPath, func(walkPath string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if walkPath == srcPath { + return nil // Skip root + } + + relPath, err := filepath.Rel(srcPath, walkPath) + if err != nil { + return fmt.Errorf("relative path: %w", err) + } + + // Use path.Join (not filepath.Join) for guest paths to ensure forward slashes + // Convert Windows backslashes to forward slashes for Linux guest + targetPath := path.Join(dstPath, filepath.ToSlash(relPath)) + info, err := d.Info() + if err != nil { + return fmt.Errorf("info: %w", err) + } + + // Check for symlink cycles when following links + if followLinks && info.Mode()&fs.ModeSymlink != 0 { + // Resolve the symlink to its real path + realPath, err := filepath.EvalSymlinks(walkPath) + if err != nil { + // If we can't resolve the symlink, skip it (might be broken) + return nil + } + realInfo, err := os.Stat(realPath) + if err != nil { + return nil // Skip broken symlinks + } + // If it's a directory symlink, check for cycles + if realInfo.IsDir() { + if visitedDirs[realPath] { + // Cycle detected, skip this symlink to prevent infinite recursion + return nil + } + visitedDirs[realPath] = true + } + } + + // Determine the mode to use + // For symlinks: if following links, let CpToInstance auto-detect from target + // (symlinks show 0777 but that's not the target's actual mode) + var mode fs.FileMode + if info.Mode()&fs.ModeSymlink != 0 && followLinks { + mode = 0 // Let CpToInstance use os.Stat to get target's mode + } else { + mode = info.Mode().Perm() + } + + // For each file/dir, we need a new WebSocket connection + // This is because the protocol is one-file-per-connection + return cpToInstanceInternal(ctx, cfg, CpToInstanceOptions{ + InstanceID: instanceID, + SrcPath: walkPath, + DstPath: targetPath, + Mode: mode, + Archive: archive, + FollowLinks: followLinks, + Dialer: dialer, + Callbacks: callbacks, + }, visitedDirs) + }) +} + +// CpFromInstance copies a file or directory from a running instance. +// +// Example: +// +// cfg, _ := lib.ExtractCpConfig(client.Options) +// err := lib.CpFromInstance(ctx, cfg, lib.CpFromInstanceOptions{ +// InstanceID: "inst_123", +// SrcPath: "/app/output.txt", +// DstPath: "./local-output.txt", +// }) +func CpFromInstance(ctx context.Context, cfg CpConfig, opts CpFromInstanceOptions) error { + // Build WebSocket URL + wsURL, err := buildWsURL(cfg.BaseURL, opts.InstanceID) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + + // Connect to WebSocket + headers := http.Header{} + headers.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey)) + + // Use provided dialer or default + dialer := opts.Dialer + if dialer == nil { + dialer = &DefaultDialer{} + } + + ws, resp, err := dialer.DialContext(ctx, wsURL, headers) + if err != nil { + if resp != nil { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("websocket connect failed (HTTP %d): %s", resp.StatusCode, string(body)) + } + return fmt.Errorf("websocket connect failed: %w", err) + } + defer ws.Close() + + // Send initial request + req := cpRequest{ + Direction: "from", + GuestPath: opts.SrcPath, + FollowLinks: opts.FollowLinks, + } + reqJSON, _ := json.Marshal(req) + if err := ws.WriteMessage(websocket.TextMessage, reqJSON); err != nil { + return fmt.Errorf("send request: %w", err) + } + + var currentFile *os.File + var currentHeader *cpFileHeader + var bytesReceived int64 + var receivedFinal bool + + // Ensure any open file is closed on function exit (fixes file handle leak) + defer func() { + if currentFile != nil { + currentFile.Close() + } + }() + + for { + msgType, message, err := ws.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + break + } + return fmt.Errorf("read message: %w", err) + } + + if msgType == websocket.TextMessage { + // Parse JSON message + var msgMap map[string]interface{} + if err := json.Unmarshal(message, &msgMap); err != nil { + return fmt.Errorf("parse message: %w", err) + } + + msgType, _ := msgMap["type"].(string) + + switch msgType { + case "header": + // Close previous file if any + if currentFile != nil { + currentFile.Close() + currentFile = nil + } + + var header cpFileHeader + if err := json.Unmarshal(message, &header); err != nil { + return fmt.Errorf("parse header: %w", err) + } + currentHeader = &header + + // Sanitize server-provided path to prevent path traversal attacks + targetPath, err := sanitizePath(opts.DstPath, header.Path) + if err != nil { + return fmt.Errorf("invalid path from server: %w", err) + } + + if header.IsDir { + if err := os.MkdirAll(targetPath, fs.FileMode(header.Mode)); err != nil { + return fmt.Errorf("create directory %s: %w", targetPath, err) + } + // Apply ownership if archive mode + if opts.Archive { + os.Chown(targetPath, int(header.Uid), int(header.Gid)) + } + } else if header.IsSymlink { + // Validate symlink target to prevent pointing outside destination + if filepath.IsAbs(header.LinkTarget) || strings.HasPrefix(filepath.Clean(header.LinkTarget), "..") { + return fmt.Errorf("invalid symlink target: %s", header.LinkTarget) + } + // Create parent directory if needed + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return fmt.Errorf("create parent dir for symlink: %w", err) + } + os.Remove(targetPath) + if err := os.Symlink(header.LinkTarget, targetPath); err != nil { + return fmt.Errorf("create symlink %s: %w", targetPath, err) + } + // Apply ownership if archive mode (use Lchown for symlinks) + if opts.Archive { + os.Lchown(targetPath, int(header.Uid), int(header.Gid)) + } + } else { + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return fmt.Errorf("create parent dir: %w", err) + } + f, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fs.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("create file %s: %w", targetPath, err) + } + currentFile = f + // Notify file start + if opts.Callbacks != nil && opts.Callbacks.OnFileStart != nil { + opts.Callbacks.OnFileStart(header.Path, header.Size) + } + } + + case "end": + var endMarker cpEndMarker + if err := json.Unmarshal(message, &endMarker); err != nil { + return fmt.Errorf("invalid end marker: %w", err) + } + + if currentFile != nil { + currentFile.Close() + if currentHeader != nil { + // Path was already validated when file was created + targetPath, _ := sanitizePath(opts.DstPath, currentHeader.Path) + if currentHeader.Mtime > 0 { + mtime := time.Unix(currentHeader.Mtime, 0) + os.Chtimes(targetPath, mtime, mtime) + } + // Apply ownership if archive mode + if opts.Archive { + os.Chown(targetPath, int(currentHeader.Uid), int(currentHeader.Gid)) + } + // Notify file end + if opts.Callbacks != nil && opts.Callbacks.OnFileEnd != nil { + opts.Callbacks.OnFileEnd(currentHeader.Path) + } + } + currentFile = nil + currentHeader = nil + bytesReceived = 0 // Reset for next file + } + + if endMarker.Final { + receivedFinal = true + return nil + } + + case "error": + var cpErr cpError + json.Unmarshal(message, &cpErr) + return fmt.Errorf("copy error at %s: %s", cpErr.Path, cpErr.Message) + + case "result": + var result cpResult + json.Unmarshal(message, &result) + if !result.Success { + return fmt.Errorf("copy failed: %s", result.Error) + } + } + } else if msgType == websocket.BinaryMessage { + // File data + if currentFile != nil { + n, err := currentFile.Write(message) + if err != nil { + return fmt.Errorf("write: %w", err) + } + bytesReceived += int64(n) + // Notify progress + if opts.Callbacks != nil && opts.Callbacks.OnProgress != nil { + opts.Callbacks.OnProgress(bytesReceived) + } + } + } + } + + // If connection closed without receiving final marker, the transfer was incomplete + if !receivedFinal { + return fmt.Errorf("copy stream ended without completion marker") + } + return nil +} + +// sanitizePath ensures the path doesn't escape the base directory. +// This prevents path traversal attacks where server-provided paths contain ".." components. +func sanitizePath(base, path string) (string, error) { + // Clean the path to resolve any . or .. components + cleaned := filepath.Clean(path) + + // Reject absolute paths + if filepath.IsAbs(cleaned) { + return "", fmt.Errorf("invalid path: absolute paths not allowed: %s", path) + } + + // Reject paths that start with .. + if strings.HasPrefix(cleaned, "..") { + return "", fmt.Errorf("invalid path: path escapes destination: %s", path) + } + + // Join with base and verify the result is under base + result := filepath.Join(base, cleaned) + absBase, err := filepath.Abs(base) + if err != nil { + return "", fmt.Errorf("resolve base path: %w", err) + } + absResult, err := filepath.Abs(result) + if err != nil { + return "", fmt.Errorf("resolve result path: %w", err) + } + + // Ensure the result is under the base directory + // Special case: if base is root ("/"), everything under it is valid + isRoot := absBase == "/" || absBase == string(filepath.Separator) + if !isRoot && !strings.HasPrefix(absResult, absBase+string(filepath.Separator)) && absResult != absBase { + return "", fmt.Errorf("invalid path: path escapes destination: %s", path) + } + + return result, nil +} + +// buildWsURL builds the WebSocket URL for the cp endpoint +func buildWsURL(baseURL, instanceID string) (string, error) { + // Validate instanceID to prevent path traversal attacks + if instanceID == "" { + return "", fmt.Errorf("instance ID cannot be empty") + } + if strings.Contains(instanceID, "/") || strings.Contains(instanceID, "\\") || strings.Contains(instanceID, "..") { + return "", fmt.Errorf("invalid instance ID: contains path separator or traversal sequence") + } + + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("invalid base URL: %w", err) + } + + // Append to existing path (preserves any path prefix like /api) + // Use path.Join to handle trailing slashes and ensure clean paths + u.Path = path.Join(u.Path, "instances", instanceID, "cp") + + switch u.Scheme { + case "https": + u.Scheme = "wss" + case "http": + u.Scheme = "ws" + } + + return u.String(), nil +} + +// CpToInstanceFromURL is a convenience function that uses base URL and API key directly. +func CpToInstanceFromURL(ctx context.Context, baseURL, apiKey string, opts CpToInstanceOptions) error { + cfg := CpConfig{ + BaseURL: baseURL, + APIKey: apiKey, + } + return CpToInstance(ctx, cfg, opts) +} + +// CpFromInstanceFromURL is a convenience function that uses base URL and API key directly. +func CpFromInstanceFromURL(ctx context.Context, baseURL, apiKey string, opts CpFromInstanceOptions) error { + cfg := CpConfig{ + BaseURL: baseURL, + APIKey: apiKey, + } + return CpFromInstance(ctx, cfg, opts) +} + diff --git a/lib/cp_test.go b/lib/cp_test.go new file mode 100644 index 0000000..f138d31 --- /dev/null +++ b/lib/cp_test.go @@ -0,0 +1,421 @@ +package lib + +import ( + "context" + "encoding/json" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// wsMessage represents a message to be sent/received +type wsMessage struct { + Type int + Data []byte +} + +// MockWsConn implements WsConn for testing +type MockWsConn struct { + writtenMessages []wsMessage + readQueue []wsMessage + readIndex int + closed bool +} + +func (m *MockWsConn) WriteMessage(messageType int, data []byte) error { + m.writtenMessages = append(m.writtenMessages, wsMessage{Type: messageType, Data: data}) + return nil +} + +func (m *MockWsConn) ReadMessage() (int, []byte, error) { + if m.readIndex >= len(m.readQueue) { + return 0, nil, &websocket.CloseError{Code: websocket.CloseNormalClosure} + } + msg := m.readQueue[m.readIndex] + m.readIndex++ + return msg.Type, msg.Data, nil +} + +func (m *MockWsConn) Close() error { + m.closed = true + return nil +} + +// MockWsDialer implements WsDialer for testing +type MockWsDialer struct { + conn *MockWsConn + dialErr error + dialResp *http.Response +} + +func (d *MockWsDialer) DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error) { + if d.dialErr != nil { + return nil, d.dialResp, d.dialErr + } + return d.conn, nil, nil +} + +// TestSanitizePath tests the path sanitization function +func TestSanitizePath(t *testing.T) { + tests := []struct { + name string + base string + path string + want string + wantErr bool + errMsg string + }{ + { + name: "normal file", + base: "/dest", + path: "file.txt", + want: "/dest/file.txt", + wantErr: false, + }, + { + name: "subdirectory file", + base: "/dest", + path: "sub/dir/file.txt", + want: "/dest/sub/dir/file.txt", + wantErr: false, + }, + { + name: "path traversal attack", + base: "/dest", + path: "../../../etc/passwd", + wantErr: true, + errMsg: "path escapes destination", + }, + { + name: "absolute path attack", + base: "/dest", + path: "/etc/passwd", + wantErr: true, + errMsg: "absolute paths not allowed", + }, + { + name: "dot-dot in middle", + base: "/dest", + path: "sub/../../../etc/passwd", + wantErr: true, + errMsg: "path escapes destination", + }, + { + name: "current dir reference", + base: "/dest", + path: "./file.txt", + want: "/dest/file.txt", + wantErr: false, + }, + { + name: "empty path", + base: "/dest", + path: "", + want: "/dest", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sanitizePath(tt.base, tt.path) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +// TestCpToInstance_SingleFile tests copying a single file +func TestCpToInstance_SingleFile(t *testing.T) { + // Create a temp file to copy + tmpDir := t.TempDir() + srcFile := filepath.Join(tmpDir, "test.txt") + err := os.WriteFile(srcFile, []byte("hello world"), 0644) + require.NoError(t, err) + + // Create mock that returns success + successResult, _ := json.Marshal(cpResult{Type: "result", Success: true, BytesWritten: 11}) + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: successResult}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err = CpToInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpToInstanceOptions{ + InstanceID: "inst_123", + SrcPath: srcFile, + DstPath: "/app/test.txt", + Dialer: mockDialer, + }) + + require.NoError(t, err) + assert.True(t, mockConn.closed) + + // Verify messages were sent + require.GreaterOrEqual(t, len(mockConn.writtenMessages), 2) // request + data + end + + // First message should be JSON request + var req cpRequest + err = json.Unmarshal(mockConn.writtenMessages[0].Data, &req) + require.NoError(t, err) + assert.Equal(t, "to", req.Direction) + assert.Equal(t, "/app/test.txt", req.GuestPath) + assert.False(t, req.IsDir) +} + +// TestCpFromInstance_PathTraversal tests that path traversal is rejected +func TestCpFromInstance_PathTraversal(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock that sends a malicious header + maliciousHeader, _ := json.Marshal(cpFileHeader{ + Type: "header", + Path: "../../../etc/passwd", + Mode: 0644, + IsDir: false, + Size: 100, + }) + + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: maliciousHeader}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err := CpFromInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpFromInstanceOptions{ + InstanceID: "inst_123", + SrcPath: "/app/file.txt", + DstPath: tmpDir, + Dialer: mockDialer, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid path from server") +} + +// TestCpFromInstance_AbsoluteSymlinkTarget tests that absolute symlink targets are rejected +func TestCpFromInstance_AbsoluteSymlinkTarget(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock that sends a symlink with absolute target + maliciousHeader, _ := json.Marshal(cpFileHeader{ + Type: "header", + Path: "link", + Mode: 0777, + IsSymlink: true, + LinkTarget: "/etc/passwd", + }) + + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: maliciousHeader}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err := CpFromInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpFromInstanceOptions{ + InstanceID: "inst_123", + SrcPath: "/app/", + DstPath: tmpDir, + Dialer: mockDialer, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid symlink target") +} + +// TestCpFromInstance_TraversingSymlinkTarget tests that traversing symlink targets are rejected +func TestCpFromInstance_TraversingSymlinkTarget(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock that sends a symlink with traversing target + maliciousHeader, _ := json.Marshal(cpFileHeader{ + Type: "header", + Path: "link", + Mode: 0777, + IsSymlink: true, + LinkTarget: "../../../etc/passwd", + }) + + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: maliciousHeader}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err := CpFromInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpFromInstanceOptions{ + InstanceID: "inst_123", + SrcPath: "/app/", + DstPath: tmpDir, + Dialer: mockDialer, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid symlink target") +} + +// TestCpFromInstance_NormalFile tests successful file copy from instance +func TestCpFromInstance_NormalFile(t *testing.T) { + tmpDir := t.TempDir() + + // Create mock that sends a valid file + header, _ := json.Marshal(cpFileHeader{ + Type: "header", + Path: "output.txt", + Mode: 0644, + IsDir: false, + Size: 11, + Mtime: 1234567890, + }) + endMarker, _ := json.Marshal(cpEndMarker{Type: "end", Final: true}) + + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: header}, + {Type: websocket.BinaryMessage, Data: []byte("hello world")}, + {Type: websocket.TextMessage, Data: endMarker}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err := CpFromInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpFromInstanceOptions{ + InstanceID: "inst_123", + SrcPath: "/app/output.txt", + DstPath: tmpDir, + Dialer: mockDialer, + }) + + require.NoError(t, err) + + // Verify file was created + content, err := os.ReadFile(filepath.Join(tmpDir, "output.txt")) + require.NoError(t, err) + assert.Equal(t, "hello world", string(content)) +} + +// TestCpCallbacks tests that callbacks are invoked correctly +func TestCpCallbacks(t *testing.T) { + // Create a temp file to copy + tmpDir := t.TempDir() + srcFile := filepath.Join(tmpDir, "test.txt") + err := os.WriteFile(srcFile, []byte("hello world"), 0644) + require.NoError(t, err) + + // Track callback invocations + var fileStartCalled bool + var progressCalled bool + var fileEndCalled bool + var progressBytes int64 + + callbacks := &CpCallbacks{ + OnFileStart: func(path string, size int64) { + fileStartCalled = true + assert.Equal(t, srcFile, path) + assert.Equal(t, int64(11), size) + }, + OnProgress: func(bytesCopied int64) { + progressCalled = true + progressBytes = bytesCopied + }, + OnFileEnd: func(path string) { + fileEndCalled = true + }, + } + + // Create mock that returns success + successResult, _ := json.Marshal(cpResult{Type: "result", Success: true, BytesWritten: 11}) + mockConn := &MockWsConn{ + readQueue: []wsMessage{ + {Type: websocket.TextMessage, Data: successResult}, + }, + } + + mockDialer := &MockWsDialer{conn: mockConn} + + err = CpToInstance(context.Background(), CpConfig{ + BaseURL: "http://localhost:8080", + APIKey: "test-key", + }, CpToInstanceOptions{ + InstanceID: "inst_123", + SrcPath: srcFile, + DstPath: "/app/test.txt", + Callbacks: callbacks, + Dialer: mockDialer, + }) + + require.NoError(t, err) + assert.True(t, fileStartCalled, "OnFileStart should be called") + assert.True(t, progressCalled, "OnProgress should be called") + assert.True(t, fileEndCalled, "OnFileEnd should be called") + assert.Equal(t, int64(11), progressBytes) +} + +// TestBuildWsURL tests WebSocket URL construction +func TestBuildWsURL(t *testing.T) { + tests := []struct { + name string + baseURL string + instanceID string + want string + wantErr bool + }{ + { + name: "https to wss", + baseURL: "https://api.example.com", + instanceID: "inst_123", + want: "wss://api.example.com/instances/inst_123/cp", + }, + { + name: "http to ws", + baseURL: "http://localhost:8080", + instanceID: "inst_456", + want: "ws://localhost:8080/instances/inst_456/cp", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := buildWsURL(tt.baseURL, tt.instanceID) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + diff --git a/lib/ws_connector.go b/lib/ws_connector.go new file mode 100644 index 0000000..8e02e95 --- /dev/null +++ b/lib/ws_connector.go @@ -0,0 +1,38 @@ +// Package lib provides manually-maintained functionality that extends the auto-generated SDK. +package lib + +import ( + "context" + "net/http" + + "github.com/gorilla/websocket" +) + +// WsConn abstracts a WebSocket connection for testing. +type WsConn interface { + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) + Close() error +} + +// WsDialer abstracts WebSocket connection creation for testing. +type WsDialer interface { + DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error) +} + +// DefaultDialer uses gorilla/websocket for real WebSocket connections. +type DefaultDialer struct{} + +// DialContext connects to a WebSocket server. +func (d *DefaultDialer) DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error) { + dialer := websocket.Dialer{} + conn, resp, err := dialer.DialContext(ctx, url, headers) + if err != nil { + return nil, resp, err + } + return conn, resp, nil +} + +// Ensure gorilla websocket.Conn implements WsConn +var _ WsConn = (*websocket.Conn)(nil) +