diff --git a/.stats.yml b/.stats.yml
index dbab236..377ee95 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 24
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fhypeman-8fded10e90df28c07b64a92d12d665d54749b9fc13c35520667637fc596957d9.yml
-openapi_spec_hash: 7374a732372bddf7f2c0b532b56ae3fb
-config_hash: 510018ffa6ad6a17875954f66fe69598
+configured_endpoints: 30
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fhypeman-28e78b73c796f9ee866671ed946402b5d569e683c3207d57c9143eb7d6f83fb6.yml
+openapi_spec_hash: fce0ac8713369a5f048bac684ed34fc8
+config_hash: f65a6a2bcef49a9f623212f9de6d6f6f
diff --git a/api.md b/api.md
index 11cb961..cb5f698 100644
--- a/api.md
+++ b/api.md
@@ -30,6 +30,7 @@ Params Types:
Response Types:
- hypeman.Instance
+- hypeman.PathInfo
- hypeman.VolumeMount
Methods:
@@ -42,6 +43,7 @@ Methods:
- client.Instances.Restore(ctx context.Context, id string) (hypeman.Instance, error)
- client.Instances.Standby(ctx context.Context, id string) (hypeman.Instance, error)
- client.Instances.Start(ctx context.Context, id string) (hypeman.Instance, error)
+- client.Instances.Stat(ctx context.Context, id string, query hypeman.InstanceStatParams) (hypeman.PathInfo, error)
- client.Instances.Stop(ctx context.Context, id string) (hypeman.Instance, error)
## Volumes
diff --git a/go.mod b/go.mod
index 52bc772..ca5e04e 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,8 @@ go 1.24.0
require (
github.com/google/go-containerregistry v0.20.7
+ github.com/gorilla/websocket v1.5.3
+ github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
)
@@ -15,6 +17,7 @@ require (
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/stargz-snapshotter/estargz v0.18.1 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/cli v29.0.3+incompatible // indirect
github.com/docker/distribution v2.8.3+incompatible // indirect
@@ -32,6 +35,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
@@ -49,4 +53,5 @@ require (
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 0e49963..fa2fd63 100644
--- a/go.sum
+++ b/go.sum
@@ -42,10 +42,16 @@ github.com/google/go-containerregistry v0.20.7 h1:24VGNpS0IwrOZ2ms2P1QE3Xa5X9p4p
github.com/google/go-containerregistry v0.20.7/go.mod h1:Lx5LCZQjLH1QBaMPeGwsME9biPeo1lPx6lbGj/UmzgM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
+github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
@@ -66,6 +72,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
+github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -125,6 +133,8 @@ google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/instance.go b/instance.go
index 6d37852..d393460 100644
--- a/instance.go
+++ b/instance.go
@@ -144,6 +144,19 @@ func (r *InstanceService) Start(ctx context.Context, id string, opts ...option.R
return
}
+// Returns information about a path in the guest filesystem. Useful for checking if
+// a path exists, its type, and permissions before performing file operations.
+func (r *InstanceService) Stat(ctx context.Context, id string, query InstanceStatParams, opts ...option.RequestOption) (res *PathInfo, err error) {
+ opts = slices.Concat(r.Options, opts)
+ if id == "" {
+ err = errors.New("missing required id parameter")
+ return
+ }
+ path := fmt.Sprintf("instances/%s/stat", id)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...)
+ return
+}
+
// Stop instance (graceful shutdown)
func (r *InstanceService) Stop(ctx context.Context, id string, opts ...option.RequestOption) (res *Instance, err error) {
opts = slices.Concat(r.Options, opts)
@@ -277,6 +290,45 @@ func (r *InstanceNetwork) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+type PathInfo struct {
+ // Whether the path exists
+ Exists bool `json:"exists,required"`
+ // Error message if stat failed (e.g., permission denied). Only set when exists is
+ // false due to an error rather than the path not existing.
+ Error string `json:"error,nullable"`
+ // True if this is a directory
+ IsDir bool `json:"is_dir"`
+ // True if this is a regular file
+ IsFile bool `json:"is_file"`
+ // True if this is a symbolic link (only set when follow_links=false)
+ IsSymlink bool `json:"is_symlink"`
+ // Symlink target path (only set when is_symlink=true)
+ LinkTarget string `json:"link_target,nullable"`
+ // File mode (Unix permissions)
+ Mode int64 `json:"mode"`
+ // File size in bytes
+ Size int64 `json:"size"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Exists respjson.Field
+ Error respjson.Field
+ IsDir respjson.Field
+ IsFile respjson.Field
+ IsSymlink respjson.Field
+ LinkTarget respjson.Field
+ Mode respjson.Field
+ Size respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r PathInfo) RawJSON() string { return r.JSON.raw }
+func (r *PathInfo) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
type VolumeMount struct {
// Path where volume is mounted in the guest
MountPath string `json:"mount_path,required"`
@@ -354,6 +406,8 @@ type InstanceNewParams struct {
Size param.Opt[string] `json:"size,omitzero"`
// Number of virtual CPUs
Vcpus param.Opt[int64] `json:"vcpus,omitzero"`
+ // Device IDs or names to attach for GPU/PCI passthrough
+ Devices []string `json:"devices,omitzero"`
// Environment variables
Env map[string]string `json:"env,omitzero"`
// Network configuration for the instance
@@ -422,3 +476,19 @@ const (
InstanceLogsParamsSourceVmm InstanceLogsParamsSource = "vmm"
InstanceLogsParamsSourceHypeman InstanceLogsParamsSource = "hypeman"
)
+
+type InstanceStatParams struct {
+ // Path to stat in the guest filesystem
+ Path string `query:"path,required" json:"-"`
+ // Follow symbolic links (like stat vs lstat)
+ FollowLinks param.Opt[bool] `query:"follow_links,omitzero" json:"-"`
+ paramObj
+}
+
+// URLQuery serializes [InstanceStatParams]'s query parameters as `url.Values`.
+func (r InstanceStatParams) URLQuery() (v url.Values, err error) {
+ return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
+ ArrayFormat: apiquery.ArrayQueryFormatComma,
+ NestedFormat: apiquery.NestedQueryFormatBrackets,
+ })
+}
diff --git a/instance_test.go b/instance_test.go
index 4910848..e7642af 100644
--- a/instance_test.go
+++ b/instance_test.go
@@ -195,6 +195,36 @@ func TestInstanceStart(t *testing.T) {
}
}
+func TestInstanceStatWithOptionalParams(t *testing.T) {
+ t.Skip("Prism tests are disabled")
+ baseURL := "http://localhost:4010"
+ if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok {
+ baseURL = envURL
+ }
+ if !testutil.CheckTestServer(t, baseURL) {
+ return
+ }
+ client := hypeman.NewClient(
+ option.WithBaseURL(baseURL),
+ option.WithAPIKey("My API Key"),
+ )
+ _, err := client.Instances.Stat(
+ context.TODO(),
+ "id",
+ hypeman.InstanceStatParams{
+ Path: "path",
+ FollowLinks: hypeman.Bool(true),
+ },
+ )
+ if err != nil {
+ var apierr *hypeman.Error
+ if errors.As(err, &apierr) {
+ t.Log(string(apierr.DumpRequest(true)))
+ }
+ t.Fatalf("err should be nil: %s", err.Error())
+ }
+}
+
func TestInstanceStop(t *testing.T) {
t.Skip("Prism tests are disabled")
baseURL := "http://localhost:4010"
diff --git a/lib/cp.go b/lib/cp.go
new file mode 100644
index 0000000..0d39038
--- /dev/null
+++ b/lib/cp.go
@@ -0,0 +1,703 @@
+// Package lib provides manually-maintained functionality that extends the auto-generated SDK.
+package lib
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/fs"
+ "net/http"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/onkernel/hypeman-go/internal/requestconfig"
+)
+
+// CpConfig holds the configuration needed for copy operations.
+// Extract this from a hypeman.Client using ExtractCpConfig.
+type CpConfig struct {
+ // BaseURL is the base URL for the hypeman API
+ BaseURL string
+ // APIKey is the JWT token for authentication
+ APIKey string
+}
+
+// ExtractCpConfig extracts the base URL and API key from client options.
+func ExtractCpConfig(opts []requestconfig.RequestOption) (CpConfig, error) {
+ cfg := &requestconfig.RequestConfig{}
+ if err := cfg.Apply(opts...); err != nil {
+ return CpConfig{}, fmt.Errorf("apply options: %w", err)
+ }
+
+ baseURL := cfg.BaseURL
+ if baseURL == nil {
+ baseURL = cfg.DefaultBaseURL
+ }
+ if baseURL == nil {
+ return CpConfig{}, fmt.Errorf("base URL not configured")
+ }
+
+ return CpConfig{
+ BaseURL: baseURL.String(),
+ APIKey: cfg.APIKey,
+ }, nil
+}
+
+// CpCallbacks provides optional progress callbacks for copy operations.
+type CpCallbacks struct {
+ OnFileStart func(path string, size int64) // Called when a file starts copying
+ OnProgress func(bytesCopied int64) // Called as bytes are copied
+ OnFileEnd func(path string) // Called when a file finishes copying
+}
+
+// CpToInstanceOptions configures a copy-to-instance operation
+type CpToInstanceOptions struct {
+ InstanceID string // Instance ID to copy to
+ SrcPath string // Local source path
+ DstPath string // Destination path in guest
+ Mode fs.FileMode // Optional: override file mode (0 = auto-detect)
+ Archive bool // Preserve UID/GID ownership
+ FollowLinks bool // Follow symbolic links when copying
+ Callbacks *CpCallbacks // Optional: progress callbacks
+ Dialer WsDialer // Optional: custom WebSocket dialer (for testing)
+}
+
+// CpFromInstanceOptions configures a copy-from-instance operation
+type CpFromInstanceOptions struct {
+ InstanceID string // Instance ID to copy from
+ SrcPath string // Source path in guest
+ DstPath string // Local destination path
+ FollowLinks bool // Follow symbolic links
+ Archive bool // Preserve UID/GID ownership
+ Callbacks *CpCallbacks // Optional: progress callbacks
+ Dialer WsDialer // Optional: custom WebSocket dialer (for testing)
+}
+
+// cpRequest is the JSON request sent over WebSocket
+type cpRequest struct {
+ Direction string `json:"direction"`
+ GuestPath string `json:"guest_path"`
+ IsDir bool `json:"is_dir,omitempty"`
+ Mode uint32 `json:"mode,omitempty"`
+ FollowLinks bool `json:"follow_links,omitempty"`
+ Uid uint32 `json:"uid,omitempty"`
+ Gid uint32 `json:"gid,omitempty"`
+}
+
+// cpFileHeader is received from the server when copying from guest
+type cpFileHeader struct {
+ Type string `json:"type"`
+ Path string `json:"path"`
+ Mode uint32 `json:"mode"`
+ IsDir bool `json:"is_dir"`
+ IsSymlink bool `json:"is_symlink"`
+ LinkTarget string `json:"link_target"`
+ Size int64 `json:"size"`
+ Mtime int64 `json:"mtime"`
+ Uid uint32 `json:"uid,omitempty"`
+ Gid uint32 `json:"gid,omitempty"`
+}
+
+// cpEndMarker signals end of file or transfer
+type cpEndMarker struct {
+ Type string `json:"type"`
+ Final bool `json:"final"`
+}
+
+// cpResult is the response from a copy-to operation
+type cpResult struct {
+ Type string `json:"type"`
+ Success bool `json:"success"`
+ Error string `json:"error,omitempty"`
+ BytesWritten int64 `json:"bytes_written,omitempty"`
+}
+
+// cpError is an error message from the server
+type cpError struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+ Path string `json:"path,omitempty"`
+}
+
+// CpToInstance copies a file or directory to a running instance.
+//
+// Example:
+//
+// cfg, _ := lib.ExtractCpConfig(client.Options)
+// err := lib.CpToInstance(ctx, cfg, lib.CpToInstanceOptions{
+// InstanceID: "inst_123",
+// SrcPath: "./local-file.txt",
+// DstPath: "/app/file.txt",
+// })
+func CpToInstance(ctx context.Context, cfg CpConfig, opts CpToInstanceOptions) error {
+ return cpToInstanceInternal(ctx, cfg, opts, nil)
+}
+
+// cpToInstanceInternal is the internal implementation that accepts visitedDirs for cycle detection
+func cpToInstanceInternal(ctx context.Context, cfg CpConfig, opts CpToInstanceOptions, visitedDirs map[string]bool) error {
+ // Build WebSocket URL
+ wsURL, err := buildWsURL(cfg.BaseURL, opts.InstanceID)
+ if err != nil {
+ return fmt.Errorf("build ws url: %w", err)
+ }
+
+ // Connect to WebSocket
+ headers := http.Header{}
+ headers.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey))
+
+ // Use provided dialer or default
+ dialer := opts.Dialer
+ if dialer == nil {
+ dialer = &DefaultDialer{}
+ }
+
+ ws, resp, err := dialer.DialContext(ctx, wsURL, headers)
+ if err != nil {
+ if resp != nil {
+ defer resp.Body.Close()
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("websocket connect failed (HTTP %d): %s", resp.StatusCode, string(body))
+ }
+ return fmt.Errorf("websocket connect failed: %w", err)
+ }
+ defer ws.Close()
+
+ // Stat the source
+ srcInfo, err := os.Stat(opts.SrcPath)
+ if err != nil {
+ return fmt.Errorf("stat source: %w", err)
+ }
+
+ mode := opts.Mode
+ if mode == 0 {
+ mode = srcInfo.Mode().Perm()
+ }
+
+ // Get UID/GID if archive mode is enabled
+ var uid, gid uint32
+ if opts.Archive {
+ if stat, ok := srcInfo.Sys().(*syscall.Stat_t); ok {
+ uid = stat.Uid
+ gid = stat.Gid
+ }
+ }
+
+ // Send initial request
+ req := cpRequest{
+ Direction: "to",
+ GuestPath: opts.DstPath,
+ IsDir: srcInfo.IsDir(),
+ Mode: uint32(mode),
+ FollowLinks: opts.FollowLinks,
+ Uid: uid,
+ Gid: gid,
+ }
+ reqJSON, _ := json.Marshal(req)
+ if err := ws.WriteMessage(websocket.TextMessage, reqJSON); err != nil {
+ return fmt.Errorf("send request: %w", err)
+ }
+
+ if srcInfo.IsDir() {
+ // Track visited directories to detect symlink cycles
+ if visitedDirs == nil {
+ visitedDirs = make(map[string]bool)
+ }
+ absPath, _ := filepath.Abs(opts.SrcPath)
+ absPath, _ = filepath.EvalSymlinks(absPath)
+ if !visitedDirs[absPath] {
+ visitedDirs[absPath] = true
+ }
+ return copyDirToWs(ctx, cfg, ws, opts.SrcPath, opts.DstPath, opts.InstanceID, opts.Archive, opts.FollowLinks, opts.Dialer, opts.Callbacks, visitedDirs)
+ }
+ return copyFileToWs(ws, opts.SrcPath, srcInfo.Size(), opts.Callbacks)
+}
+
+// copyFileToWs copies a single file to the WebSocket
+func copyFileToWs(ws WsConn, srcPath string, size int64, callbacks *CpCallbacks) error {
+ file, err := os.Open(srcPath)
+ if err != nil {
+ return fmt.Errorf("open source: %w", err)
+ }
+ defer file.Close()
+
+ // Notify file start
+ if callbacks != nil && callbacks.OnFileStart != nil {
+ callbacks.OnFileStart(srcPath, size)
+ }
+
+ buf := make([]byte, 32*1024)
+ var bytesSent int64
+ for {
+ n, err := file.Read(buf)
+ if n > 0 {
+ if sendErr := ws.WriteMessage(websocket.BinaryMessage, buf[:n]); sendErr != nil {
+ return fmt.Errorf("send data: %w", sendErr)
+ }
+ bytesSent += int64(n)
+ // Notify progress
+ if callbacks != nil && callbacks.OnProgress != nil {
+ callbacks.OnProgress(bytesSent)
+ }
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("read source: %w", err)
+ }
+ }
+
+ // Send end marker
+ endMsg, _ := json.Marshal(map[string]string{"type": "end"})
+ if err := ws.WriteMessage(websocket.TextMessage, endMsg); err != nil {
+ return fmt.Errorf("send end: %w", err)
+ }
+
+ // Wait for result
+ _, message, err := ws.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("read result: %w", err)
+ }
+
+ // Check message type first - server may send error or result
+ var msgType struct {
+ Type string `json:"type"`
+ }
+ if err := json.Unmarshal(message, &msgType); err != nil {
+ return fmt.Errorf("parse message type: %w", err)
+ }
+
+ if msgType.Type == "error" {
+ var errMsg cpError
+ if err := json.Unmarshal(message, &errMsg); err != nil {
+ return fmt.Errorf("parse error: %w", err)
+ }
+ return fmt.Errorf("copy failed: %s", errMsg.Message)
+ }
+
+ var result cpResult
+ if err := json.Unmarshal(message, &result); err != nil {
+ return fmt.Errorf("parse result: %w", err)
+ }
+
+ if !result.Success {
+ return fmt.Errorf("copy failed: %s", result.Error)
+ }
+
+ // Notify file end (srcPath not available here, use empty string)
+ if callbacks != nil && callbacks.OnFileEnd != nil {
+ callbacks.OnFileEnd(srcPath)
+ }
+
+ return nil
+}
+
+// copyDirToWs copies a directory to the WebSocket
+func copyDirToWs(ctx context.Context, cfg CpConfig, ws WsConn, srcPath, dstPath, instanceID string, archive, followLinks bool, dialer WsDialer, callbacks *CpCallbacks, visitedDirs map[string]bool) error {
+ // For directory copy, we just send the end marker
+ // The server will create the directory
+ endMsg, _ := json.Marshal(map[string]string{"type": "end"})
+ if err := ws.WriteMessage(websocket.TextMessage, endMsg); err != nil {
+ return fmt.Errorf("send end: %w", err)
+ }
+
+ // Wait for result
+ _, message, err := ws.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("read result: %w", err)
+ }
+
+ // Check message type first - server may send error or result
+ var msgType struct {
+ Type string `json:"type"`
+ }
+ if err := json.Unmarshal(message, &msgType); err != nil {
+ return fmt.Errorf("parse message type: %w", err)
+ }
+
+ if msgType.Type == "error" {
+ var errMsg cpError
+ if err := json.Unmarshal(message, &errMsg); err != nil {
+ return fmt.Errorf("parse error: %w", err)
+ }
+ return fmt.Errorf("copy failed: %s", errMsg.Message)
+ }
+
+ var result cpResult
+ if err := json.Unmarshal(message, &result); err != nil {
+ return fmt.Errorf("parse result: %w", err)
+ }
+
+ if !result.Success {
+ return fmt.Errorf("copy failed: %s", result.Error)
+ }
+
+ // Now recursively copy contents
+ return filepath.WalkDir(srcPath, func(walkPath string, d fs.DirEntry, err error) error {
+ if err != nil {
+ return err
+ }
+
+ if walkPath == srcPath {
+ return nil // Skip root
+ }
+
+ relPath, err := filepath.Rel(srcPath, walkPath)
+ if err != nil {
+ return fmt.Errorf("relative path: %w", err)
+ }
+
+ // Use path.Join (not filepath.Join) for guest paths to ensure forward slashes
+ // Convert Windows backslashes to forward slashes for Linux guest
+ targetPath := path.Join(dstPath, filepath.ToSlash(relPath))
+ info, err := d.Info()
+ if err != nil {
+ return fmt.Errorf("info: %w", err)
+ }
+
+ // Check for symlink cycles when following links
+ if followLinks && info.Mode()&fs.ModeSymlink != 0 {
+ // Resolve the symlink to its real path
+ realPath, err := filepath.EvalSymlinks(walkPath)
+ if err != nil {
+ // If we can't resolve the symlink, skip it (might be broken)
+ return nil
+ }
+ realInfo, err := os.Stat(realPath)
+ if err != nil {
+ return nil // Skip broken symlinks
+ }
+ // If it's a directory symlink, check for cycles
+ if realInfo.IsDir() {
+ if visitedDirs[realPath] {
+ // Cycle detected, skip this symlink to prevent infinite recursion
+ return nil
+ }
+ visitedDirs[realPath] = true
+ }
+ }
+
+ // Determine the mode to use
+ // For symlinks: if following links, let CpToInstance auto-detect from target
+ // (symlinks show 0777 but that's not the target's actual mode)
+ var mode fs.FileMode
+ if info.Mode()&fs.ModeSymlink != 0 && followLinks {
+ mode = 0 // Let CpToInstance use os.Stat to get target's mode
+ } else {
+ mode = info.Mode().Perm()
+ }
+
+ // For each file/dir, we need a new WebSocket connection
+ // This is because the protocol is one-file-per-connection
+ return cpToInstanceInternal(ctx, cfg, CpToInstanceOptions{
+ InstanceID: instanceID,
+ SrcPath: walkPath,
+ DstPath: targetPath,
+ Mode: mode,
+ Archive: archive,
+ FollowLinks: followLinks,
+ Dialer: dialer,
+ Callbacks: callbacks,
+ }, visitedDirs)
+ })
+}
+
+// CpFromInstance copies a file or directory from a running instance.
+//
+// Example:
+//
+// cfg, _ := lib.ExtractCpConfig(client.Options)
+// err := lib.CpFromInstance(ctx, cfg, lib.CpFromInstanceOptions{
+// InstanceID: "inst_123",
+// SrcPath: "/app/output.txt",
+// DstPath: "./local-output.txt",
+// })
+func CpFromInstance(ctx context.Context, cfg CpConfig, opts CpFromInstanceOptions) error {
+ // Build WebSocket URL
+ wsURL, err := buildWsURL(cfg.BaseURL, opts.InstanceID)
+ if err != nil {
+ return fmt.Errorf("build ws url: %w", err)
+ }
+
+ // Connect to WebSocket
+ headers := http.Header{}
+ headers.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.APIKey))
+
+ // Use provided dialer or default
+ dialer := opts.Dialer
+ if dialer == nil {
+ dialer = &DefaultDialer{}
+ }
+
+ ws, resp, err := dialer.DialContext(ctx, wsURL, headers)
+ if err != nil {
+ if resp != nil {
+ defer resp.Body.Close()
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("websocket connect failed (HTTP %d): %s", resp.StatusCode, string(body))
+ }
+ return fmt.Errorf("websocket connect failed: %w", err)
+ }
+ defer ws.Close()
+
+ // Send initial request
+ req := cpRequest{
+ Direction: "from",
+ GuestPath: opts.SrcPath,
+ FollowLinks: opts.FollowLinks,
+ }
+ reqJSON, _ := json.Marshal(req)
+ if err := ws.WriteMessage(websocket.TextMessage, reqJSON); err != nil {
+ return fmt.Errorf("send request: %w", err)
+ }
+
+ var currentFile *os.File
+ var currentHeader *cpFileHeader
+ var bytesReceived int64
+ var receivedFinal bool
+
+ // Ensure any open file is closed on function exit (fixes file handle leak)
+ defer func() {
+ if currentFile != nil {
+ currentFile.Close()
+ }
+ }()
+
+ for {
+ msgType, message, err := ws.ReadMessage()
+ if err != nil {
+ if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+ break
+ }
+ return fmt.Errorf("read message: %w", err)
+ }
+
+ if msgType == websocket.TextMessage {
+ // Parse JSON message
+ var msgMap map[string]interface{}
+ if err := json.Unmarshal(message, &msgMap); err != nil {
+ return fmt.Errorf("parse message: %w", err)
+ }
+
+ msgType, _ := msgMap["type"].(string)
+
+ switch msgType {
+ case "header":
+ // Close previous file if any
+ if currentFile != nil {
+ currentFile.Close()
+ currentFile = nil
+ }
+
+ var header cpFileHeader
+ if err := json.Unmarshal(message, &header); err != nil {
+ return fmt.Errorf("parse header: %w", err)
+ }
+ currentHeader = &header
+
+ // Sanitize server-provided path to prevent path traversal attacks
+ targetPath, err := sanitizePath(opts.DstPath, header.Path)
+ if err != nil {
+ return fmt.Errorf("invalid path from server: %w", err)
+ }
+
+ if header.IsDir {
+ if err := os.MkdirAll(targetPath, fs.FileMode(header.Mode)); err != nil {
+ return fmt.Errorf("create directory %s: %w", targetPath, err)
+ }
+ // Apply ownership if archive mode
+ if opts.Archive {
+ os.Chown(targetPath, int(header.Uid), int(header.Gid))
+ }
+ } else if header.IsSymlink {
+ // Validate symlink target to prevent pointing outside destination
+ if filepath.IsAbs(header.LinkTarget) || strings.HasPrefix(filepath.Clean(header.LinkTarget), "..") {
+ return fmt.Errorf("invalid symlink target: %s", header.LinkTarget)
+ }
+ // Create parent directory if needed
+ if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
+ return fmt.Errorf("create parent dir for symlink: %w", err)
+ }
+ os.Remove(targetPath)
+ if err := os.Symlink(header.LinkTarget, targetPath); err != nil {
+ return fmt.Errorf("create symlink %s: %w", targetPath, err)
+ }
+ // Apply ownership if archive mode (use Lchown for symlinks)
+ if opts.Archive {
+ os.Lchown(targetPath, int(header.Uid), int(header.Gid))
+ }
+ } else {
+ if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
+ return fmt.Errorf("create parent dir: %w", err)
+ }
+ f, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fs.FileMode(header.Mode))
+ if err != nil {
+ return fmt.Errorf("create file %s: %w", targetPath, err)
+ }
+ currentFile = f
+ // Notify file start
+ if opts.Callbacks != nil && opts.Callbacks.OnFileStart != nil {
+ opts.Callbacks.OnFileStart(header.Path, header.Size)
+ }
+ }
+
+ case "end":
+ var endMarker cpEndMarker
+ if err := json.Unmarshal(message, &endMarker); err != nil {
+ return fmt.Errorf("invalid end marker: %w", err)
+ }
+
+ if currentFile != nil {
+ currentFile.Close()
+ if currentHeader != nil {
+ // Path was already validated when file was created
+ targetPath, _ := sanitizePath(opts.DstPath, currentHeader.Path)
+ if currentHeader.Mtime > 0 {
+ mtime := time.Unix(currentHeader.Mtime, 0)
+ os.Chtimes(targetPath, mtime, mtime)
+ }
+ // Apply ownership if archive mode
+ if opts.Archive {
+ os.Chown(targetPath, int(currentHeader.Uid), int(currentHeader.Gid))
+ }
+ // Notify file end
+ if opts.Callbacks != nil && opts.Callbacks.OnFileEnd != nil {
+ opts.Callbacks.OnFileEnd(currentHeader.Path)
+ }
+ }
+ currentFile = nil
+ currentHeader = nil
+ bytesReceived = 0 // Reset for next file
+ }
+
+ if endMarker.Final {
+ receivedFinal = true
+ return nil
+ }
+
+ case "error":
+ var cpErr cpError
+ json.Unmarshal(message, &cpErr)
+ return fmt.Errorf("copy error at %s: %s", cpErr.Path, cpErr.Message)
+
+ case "result":
+ var result cpResult
+ json.Unmarshal(message, &result)
+ if !result.Success {
+ return fmt.Errorf("copy failed: %s", result.Error)
+ }
+ }
+ } else if msgType == websocket.BinaryMessage {
+ // File data
+ if currentFile != nil {
+ n, err := currentFile.Write(message)
+ if err != nil {
+ return fmt.Errorf("write: %w", err)
+ }
+ bytesReceived += int64(n)
+ // Notify progress
+ if opts.Callbacks != nil && opts.Callbacks.OnProgress != nil {
+ opts.Callbacks.OnProgress(bytesReceived)
+ }
+ }
+ }
+ }
+
+ // If connection closed without receiving final marker, the transfer was incomplete
+ if !receivedFinal {
+ return fmt.Errorf("copy stream ended without completion marker")
+ }
+ return nil
+}
+
+// sanitizePath ensures the path doesn't escape the base directory.
+// This prevents path traversal attacks where server-provided paths contain ".." components.
+func sanitizePath(base, path string) (string, error) {
+ // Clean the path to resolve any . or .. components
+ cleaned := filepath.Clean(path)
+
+ // Reject absolute paths
+ if filepath.IsAbs(cleaned) {
+ return "", fmt.Errorf("invalid path: absolute paths not allowed: %s", path)
+ }
+
+ // Reject paths that start with ..
+ if strings.HasPrefix(cleaned, "..") {
+ return "", fmt.Errorf("invalid path: path escapes destination: %s", path)
+ }
+
+ // Join with base and verify the result is under base
+ result := filepath.Join(base, cleaned)
+ absBase, err := filepath.Abs(base)
+ if err != nil {
+ return "", fmt.Errorf("resolve base path: %w", err)
+ }
+ absResult, err := filepath.Abs(result)
+ if err != nil {
+ return "", fmt.Errorf("resolve result path: %w", err)
+ }
+
+ // Ensure the result is under the base directory
+ // Special case: if base is root ("/"), everything under it is valid
+ isRoot := absBase == "/" || absBase == string(filepath.Separator)
+ if !isRoot && !strings.HasPrefix(absResult, absBase+string(filepath.Separator)) && absResult != absBase {
+ return "", fmt.Errorf("invalid path: path escapes destination: %s", path)
+ }
+
+ return result, nil
+}
+
+// buildWsURL builds the WebSocket URL for the cp endpoint
+func buildWsURL(baseURL, instanceID string) (string, error) {
+ // Validate instanceID to prevent path traversal attacks
+ if instanceID == "" {
+ return "", fmt.Errorf("instance ID cannot be empty")
+ }
+ if strings.Contains(instanceID, "/") || strings.Contains(instanceID, "\\") || strings.Contains(instanceID, "..") {
+ return "", fmt.Errorf("invalid instance ID: contains path separator or traversal sequence")
+ }
+
+ u, err := url.Parse(baseURL)
+ if err != nil {
+ return "", fmt.Errorf("invalid base URL: %w", err)
+ }
+
+ // Append to existing path (preserves any path prefix like /api)
+ // Use path.Join to handle trailing slashes and ensure clean paths
+ u.Path = path.Join(u.Path, "instances", instanceID, "cp")
+
+ switch u.Scheme {
+ case "https":
+ u.Scheme = "wss"
+ case "http":
+ u.Scheme = "ws"
+ }
+
+ return u.String(), nil
+}
+
+// CpToInstanceFromURL is a convenience function that uses base URL and API key directly.
+func CpToInstanceFromURL(ctx context.Context, baseURL, apiKey string, opts CpToInstanceOptions) error {
+ cfg := CpConfig{
+ BaseURL: baseURL,
+ APIKey: apiKey,
+ }
+ return CpToInstance(ctx, cfg, opts)
+}
+
+// CpFromInstanceFromURL is a convenience function that uses base URL and API key directly.
+func CpFromInstanceFromURL(ctx context.Context, baseURL, apiKey string, opts CpFromInstanceOptions) error {
+ cfg := CpConfig{
+ BaseURL: baseURL,
+ APIKey: apiKey,
+ }
+ return CpFromInstance(ctx, cfg, opts)
+}
+
diff --git a/lib/cp_test.go b/lib/cp_test.go
new file mode 100644
index 0000000..f138d31
--- /dev/null
+++ b/lib/cp_test.go
@@ -0,0 +1,421 @@
+package lib
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/gorilla/websocket"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// wsMessage represents a message to be sent/received
+type wsMessage struct {
+ Type int
+ Data []byte
+}
+
+// MockWsConn implements WsConn for testing
+type MockWsConn struct {
+ writtenMessages []wsMessage
+ readQueue []wsMessage
+ readIndex int
+ closed bool
+}
+
+func (m *MockWsConn) WriteMessage(messageType int, data []byte) error {
+ m.writtenMessages = append(m.writtenMessages, wsMessage{Type: messageType, Data: data})
+ return nil
+}
+
+func (m *MockWsConn) ReadMessage() (int, []byte, error) {
+ if m.readIndex >= len(m.readQueue) {
+ return 0, nil, &websocket.CloseError{Code: websocket.CloseNormalClosure}
+ }
+ msg := m.readQueue[m.readIndex]
+ m.readIndex++
+ return msg.Type, msg.Data, nil
+}
+
+func (m *MockWsConn) Close() error {
+ m.closed = true
+ return nil
+}
+
+// MockWsDialer implements WsDialer for testing
+type MockWsDialer struct {
+ conn *MockWsConn
+ dialErr error
+ dialResp *http.Response
+}
+
+func (d *MockWsDialer) DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error) {
+ if d.dialErr != nil {
+ return nil, d.dialResp, d.dialErr
+ }
+ return d.conn, nil, nil
+}
+
+// TestSanitizePath tests the path sanitization function
+func TestSanitizePath(t *testing.T) {
+ tests := []struct {
+ name string
+ base string
+ path string
+ want string
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "normal file",
+ base: "/dest",
+ path: "file.txt",
+ want: "/dest/file.txt",
+ wantErr: false,
+ },
+ {
+ name: "subdirectory file",
+ base: "/dest",
+ path: "sub/dir/file.txt",
+ want: "/dest/sub/dir/file.txt",
+ wantErr: false,
+ },
+ {
+ name: "path traversal attack",
+ base: "/dest",
+ path: "../../../etc/passwd",
+ wantErr: true,
+ errMsg: "path escapes destination",
+ },
+ {
+ name: "absolute path attack",
+ base: "/dest",
+ path: "/etc/passwd",
+ wantErr: true,
+ errMsg: "absolute paths not allowed",
+ },
+ {
+ name: "dot-dot in middle",
+ base: "/dest",
+ path: "sub/../../../etc/passwd",
+ wantErr: true,
+ errMsg: "path escapes destination",
+ },
+ {
+ name: "current dir reference",
+ base: "/dest",
+ path: "./file.txt",
+ want: "/dest/file.txt",
+ wantErr: false,
+ },
+ {
+ name: "empty path",
+ base: "/dest",
+ path: "",
+ want: "/dest",
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := sanitizePath(tt.base, tt.path)
+ if tt.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMsg)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ }
+ })
+ }
+}
+
+// TestCpToInstance_SingleFile tests copying a single file
+func TestCpToInstance_SingleFile(t *testing.T) {
+ // Create a temp file to copy
+ tmpDir := t.TempDir()
+ srcFile := filepath.Join(tmpDir, "test.txt")
+ err := os.WriteFile(srcFile, []byte("hello world"), 0644)
+ require.NoError(t, err)
+
+ // Create mock that returns success
+ successResult, _ := json.Marshal(cpResult{Type: "result", Success: true, BytesWritten: 11})
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: successResult},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err = CpToInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpToInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: srcFile,
+ DstPath: "/app/test.txt",
+ Dialer: mockDialer,
+ })
+
+ require.NoError(t, err)
+ assert.True(t, mockConn.closed)
+
+ // Verify messages were sent
+ require.GreaterOrEqual(t, len(mockConn.writtenMessages), 2) // request + data + end
+
+ // First message should be JSON request
+ var req cpRequest
+ err = json.Unmarshal(mockConn.writtenMessages[0].Data, &req)
+ require.NoError(t, err)
+ assert.Equal(t, "to", req.Direction)
+ assert.Equal(t, "/app/test.txt", req.GuestPath)
+ assert.False(t, req.IsDir)
+}
+
+// TestCpFromInstance_PathTraversal tests that path traversal is rejected
+func TestCpFromInstance_PathTraversal(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create mock that sends a malicious header
+ maliciousHeader, _ := json.Marshal(cpFileHeader{
+ Type: "header",
+ Path: "../../../etc/passwd",
+ Mode: 0644,
+ IsDir: false,
+ Size: 100,
+ })
+
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: maliciousHeader},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err := CpFromInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpFromInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: "/app/file.txt",
+ DstPath: tmpDir,
+ Dialer: mockDialer,
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid path from server")
+}
+
+// TestCpFromInstance_AbsoluteSymlinkTarget tests that absolute symlink targets are rejected
+func TestCpFromInstance_AbsoluteSymlinkTarget(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create mock that sends a symlink with absolute target
+ maliciousHeader, _ := json.Marshal(cpFileHeader{
+ Type: "header",
+ Path: "link",
+ Mode: 0777,
+ IsSymlink: true,
+ LinkTarget: "/etc/passwd",
+ })
+
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: maliciousHeader},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err := CpFromInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpFromInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: "/app/",
+ DstPath: tmpDir,
+ Dialer: mockDialer,
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid symlink target")
+}
+
+// TestCpFromInstance_TraversingSymlinkTarget tests that traversing symlink targets are rejected
+func TestCpFromInstance_TraversingSymlinkTarget(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create mock that sends a symlink with traversing target
+ maliciousHeader, _ := json.Marshal(cpFileHeader{
+ Type: "header",
+ Path: "link",
+ Mode: 0777,
+ IsSymlink: true,
+ LinkTarget: "../../../etc/passwd",
+ })
+
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: maliciousHeader},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err := CpFromInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpFromInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: "/app/",
+ DstPath: tmpDir,
+ Dialer: mockDialer,
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid symlink target")
+}
+
+// TestCpFromInstance_NormalFile tests successful file copy from instance
+func TestCpFromInstance_NormalFile(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create mock that sends a valid file
+ header, _ := json.Marshal(cpFileHeader{
+ Type: "header",
+ Path: "output.txt",
+ Mode: 0644,
+ IsDir: false,
+ Size: 11,
+ Mtime: 1234567890,
+ })
+ endMarker, _ := json.Marshal(cpEndMarker{Type: "end", Final: true})
+
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: header},
+ {Type: websocket.BinaryMessage, Data: []byte("hello world")},
+ {Type: websocket.TextMessage, Data: endMarker},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err := CpFromInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpFromInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: "/app/output.txt",
+ DstPath: tmpDir,
+ Dialer: mockDialer,
+ })
+
+ require.NoError(t, err)
+
+ // Verify file was created
+ content, err := os.ReadFile(filepath.Join(tmpDir, "output.txt"))
+ require.NoError(t, err)
+ assert.Equal(t, "hello world", string(content))
+}
+
+// TestCpCallbacks tests that callbacks are invoked correctly
+func TestCpCallbacks(t *testing.T) {
+ // Create a temp file to copy
+ tmpDir := t.TempDir()
+ srcFile := filepath.Join(tmpDir, "test.txt")
+ err := os.WriteFile(srcFile, []byte("hello world"), 0644)
+ require.NoError(t, err)
+
+ // Track callback invocations
+ var fileStartCalled bool
+ var progressCalled bool
+ var fileEndCalled bool
+ var progressBytes int64
+
+ callbacks := &CpCallbacks{
+ OnFileStart: func(path string, size int64) {
+ fileStartCalled = true
+ assert.Equal(t, srcFile, path)
+ assert.Equal(t, int64(11), size)
+ },
+ OnProgress: func(bytesCopied int64) {
+ progressCalled = true
+ progressBytes = bytesCopied
+ },
+ OnFileEnd: func(path string) {
+ fileEndCalled = true
+ },
+ }
+
+ // Create mock that returns success
+ successResult, _ := json.Marshal(cpResult{Type: "result", Success: true, BytesWritten: 11})
+ mockConn := &MockWsConn{
+ readQueue: []wsMessage{
+ {Type: websocket.TextMessage, Data: successResult},
+ },
+ }
+
+ mockDialer := &MockWsDialer{conn: mockConn}
+
+ err = CpToInstance(context.Background(), CpConfig{
+ BaseURL: "http://localhost:8080",
+ APIKey: "test-key",
+ }, CpToInstanceOptions{
+ InstanceID: "inst_123",
+ SrcPath: srcFile,
+ DstPath: "/app/test.txt",
+ Callbacks: callbacks,
+ Dialer: mockDialer,
+ })
+
+ require.NoError(t, err)
+ assert.True(t, fileStartCalled, "OnFileStart should be called")
+ assert.True(t, progressCalled, "OnProgress should be called")
+ assert.True(t, fileEndCalled, "OnFileEnd should be called")
+ assert.Equal(t, int64(11), progressBytes)
+}
+
+// TestBuildWsURL tests WebSocket URL construction
+func TestBuildWsURL(t *testing.T) {
+ tests := []struct {
+ name string
+ baseURL string
+ instanceID string
+ want string
+ wantErr bool
+ }{
+ {
+ name: "https to wss",
+ baseURL: "https://api.example.com",
+ instanceID: "inst_123",
+ want: "wss://api.example.com/instances/inst_123/cp",
+ },
+ {
+ name: "http to ws",
+ baseURL: "http://localhost:8080",
+ instanceID: "inst_456",
+ want: "ws://localhost:8080/instances/inst_456/cp",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := buildWsURL(tt.baseURL, tt.instanceID)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ }
+ })
+ }
+}
+
diff --git a/lib/ws_connector.go b/lib/ws_connector.go
new file mode 100644
index 0000000..8e02e95
--- /dev/null
+++ b/lib/ws_connector.go
@@ -0,0 +1,38 @@
+// Package lib provides manually-maintained functionality that extends the auto-generated SDK.
+package lib
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/gorilla/websocket"
+)
+
+// WsConn abstracts a WebSocket connection for testing.
+type WsConn interface {
+ WriteMessage(messageType int, data []byte) error
+ ReadMessage() (messageType int, p []byte, err error)
+ Close() error
+}
+
+// WsDialer abstracts WebSocket connection creation for testing.
+type WsDialer interface {
+ DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error)
+}
+
+// DefaultDialer uses gorilla/websocket for real WebSocket connections.
+type DefaultDialer struct{}
+
+// DialContext connects to a WebSocket server.
+func (d *DefaultDialer) DialContext(ctx context.Context, url string, headers http.Header) (WsConn, *http.Response, error) {
+ dialer := websocket.Dialer{}
+ conn, resp, err := dialer.DialContext(ctx, url, headers)
+ if err != nil {
+ return nil, resp, err
+ }
+ return conn, resp, nil
+}
+
+// Ensure gorilla websocket.Conn implements WsConn
+var _ WsConn = (*websocket.Conn)(nil)
+