diff --git a/cmd/container-use/apply.go b/cmd/container-use/apply.go index bac9ae7b..cf440eba 100644 --- a/cmd/container-use/apply.go +++ b/cmd/container-use/apply.go @@ -43,7 +43,7 @@ git commit -m "Add backend API implementation"`, env := args[0] - if err := repo.Apply(ctx, env, os.Stdout); err != nil { + if err := repo.Apply(ctx, env, os.Stdout, false); err != nil { return fmt.Errorf("failed to apply environment: %w", err) } diff --git a/cmd/container-use/list.go b/cmd/container-use/list.go index 896cbbb5..b5583b85 100644 --- a/cmd/container-use/list.go +++ b/cmd/container-use/list.go @@ -33,11 +33,11 @@ Use -q for environment IDs only, useful for scripting.`, } tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, "ID\tTITLE\tCREATED\tUPDATED") + fmt.Fprintln(tw, "ID\tBRANCH\tTITLE\tCREATED\tUPDATED\tEPHEMERAL") defer tw.Flush() for _, envInfo := range envInfos { - fmt.Fprintf(tw, "%s\t%s\t%s\t%s\n", envInfo.ID, truncate(app, envInfo.State.Title, 40), humanize.Time(envInfo.State.CreatedAt), humanize.Time(envInfo.State.UpdatedAt)) + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%v\n", envInfo.ID, envInfo.State.TrackingBranch, truncate(app, envInfo.State.Title, 40), humanize.Time(envInfo.State.CreatedAt), humanize.Time(envInfo.State.UpdatedAt), envInfo.State.Ephemeral) } return nil }, diff --git a/environment/environment.go b/environment/environment.go index 7c184b9a..da2556f9 100644 --- a/environment/environment.go +++ b/environment/environment.go @@ -31,15 +31,17 @@ type Environment struct { mu sync.RWMutex } -func New(ctx context.Context, dag *dagger.Client, id, title string, config *EnvironmentConfig, initialSourceDir *dagger.Directory) (*Environment, error) { +func New(ctx context.Context, dag *dagger.Client, id, branch, title string, ephemeral bool, config *EnvironmentConfig, initialSourceDir *dagger.Directory) (*Environment, error) { env := &Environment{ EnvironmentInfo: &EnvironmentInfo{ ID: id, State: &State{ - Config: config, - Title: title, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + Config: config, + Title: title, + TrackingBranch: branch, + Ephemeral: ephemeral, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), }, }, dag: dag, diff --git a/environment/filesystem.go b/environment/filesystem.go index 84ead062..96b39010 100644 --- a/environment/filesystem.go +++ b/environment/filesystem.go @@ -2,8 +2,11 @@ package environment import ( "context" + "crypto/sha256" "fmt" "strings" + + godiffpatch "github.com/sourcegraph/go-diff-patch" ) func (env *Environment) FileRead(ctx context.Context, targetFile string, shouldReadEntireFile bool, startLineOneIndexedInclusive int, endLineOneIndexedInclusive int) (string, error) { @@ -45,6 +48,85 @@ func (env *Environment) FileWrite(ctx context.Context, explanation, targetFile, return nil } +func (env *Environment) FileSearchReplace(ctx context.Context, explanation, targetFile, search, replace, matchID string) error { + contents, err := env.container().File(targetFile).Contents(ctx) + if err != nil { + return err + } + + // Find all matches of the search text + matches := []int{} + searchIndex := 0 + for { + index := strings.Index(contents[searchIndex:], search) + if index == -1 { + break + } + actualIndex := searchIndex + index + matches = append(matches, actualIndex) + searchIndex = actualIndex + 1 + } + + if len(matches) == 0 { + return fmt.Errorf("search text not found in file %s", targetFile) + } + + // If there are multiple matches and no matchID is provided, return an error with all matches + if len(matches) > 1 && matchID == "" { + var matchDescriptions []string + for i, matchIndex := range matches { + // Generate a unique ID for each match + id := generateMatchID(targetFile, search, replace, i) + + // Get context around the match (3 lines before and after) + context := getMatchContext(contents, matchIndex, len(search)) + + matchDescriptions = append(matchDescriptions, fmt.Sprintf("Match %d (ID: %s):\n%s", i+1, id, context)) + } + + return fmt.Errorf("multiple matches found for search text in %s. Please specify which_match parameter with one of the following IDs:\n\n%s", + targetFile, strings.Join(matchDescriptions, "\n\n")) + } + + // Determine which match to replace + var targetMatchIndex int + if len(matches) == 1 { + targetMatchIndex = matches[0] + } else { + // Find the match with the specified ID + found := false + for i, matchIndex := range matches { + id := generateMatchID(targetFile, search, replace, i) + if id == matchID { + targetMatchIndex = matchIndex + found = true + break + } + } + if !found { + return fmt.Errorf("match ID %s not found", matchID) + } + } + + // Replace the specific match + newContents := contents[:targetMatchIndex] + replace + contents[targetMatchIndex+len(search):] + + // Apply the changes using `patch` so we don't have to spit out the entire + // contents + return env.ApplyPatch(ctx, godiffpatch.GeneratePatch(targetFile, contents, newContents)) +} + +func (env *Environment) ApplyPatch(ctx context.Context, patch string) error { + ctr := env.container() + err := env.apply(ctx, ctr. + WithDirectory(".", ctr.Directory(".").WithPatch(patch))) + if err != nil { + return fmt.Errorf("failed applying file edit, skipping git propagation: %w", err) + } + env.Notes.Add("Apply patch") + return nil +} + func (env *Environment) FileDelete(ctx context.Context, explanation, targetFile string) error { err := env.apply(ctx, env.container().WithoutFile(targetFile)) if err != nil { @@ -65,3 +147,41 @@ func (env *Environment) FileList(ctx context.Context, path string) (string, erro } return out.String(), nil } + +// generateMatchID creates a unique ID for a match based on file, search, replace, and index +func generateMatchID(targetFile, search, replace string, index int) string { + data := fmt.Sprintf("%s:%s:%s:%d", targetFile, search, replace, index) + hash := sha256.Sum256([]byte(data)) + return fmt.Sprintf("%x", hash)[:8] // Use first 8 characters of hash +} + +// getMatchContext returns the context around a match (3 lines before and after) +func getMatchContext(contents string, matchIndex, matchLength int) string { + lines := strings.Split(contents, "\n") + + // Find which line contains the match + currentPos := 0 + matchLine := 0 + for i, line := range lines { + if currentPos+len(line) >= matchIndex { + matchLine = i + break + } + currentPos += len(line) + 1 // +1 for newline + } + + // Get context lines (3 before, match line, 3 after) + start := max(0, matchLine-3) + end := min(len(lines), matchLine+4) + + contextLines := make([]string, 0, end-start) + for i := start; i < end; i++ { + prefix := " " + if i == matchLine { + prefix = "> " // Mark the line containing the match + } + contextLines = append(contextLines, fmt.Sprintf("%s%s", prefix, lines[i])) + } + + return strings.Join(contextLines, "\n") +} diff --git a/environment/integration/merge_test.go b/environment/integration/merge_test.go index f69347a0..18c90224 100644 --- a/environment/integration/merge_test.go +++ b/environment/integration/merge_test.go @@ -79,7 +79,7 @@ func TestRepositoryApply(t *testing.T) { // Apply the environment (squash merge) var applyOutput bytes.Buffer - err = repo.Apply(ctx, env.ID, &applyOutput) + err = repo.Apply(ctx, env.ID, &applyOutput, true) require.NoError(t, err, "Apply should succeed: %s", applyOutput.String()) // Verify we're still on the initial branch @@ -146,7 +146,7 @@ func TestRepositoryApplyNonExistent(t *testing.T) { // Try to apply non-existent environment var applyOutput bytes.Buffer - err := repo.Apply(ctx, "non-existent-env", &applyOutput) + err := repo.Apply(ctx, "non-existent-env", &applyOutput, true) assert.Error(t, err, "Applying non-existent environment should fail") assert.Contains(t, err.Error(), "not found") }) @@ -203,7 +203,7 @@ func TestRepositoryApplyWithConflicts(t *testing.T) { // Try to apply - this should fail due to conflict var applyOutput bytes.Buffer - err = repo.Apply(ctx, env.ID, &applyOutput) + err = repo.Apply(ctx, env.ID, &applyOutput, true) // The apply should fail due to conflict assert.Error(t, err, "Apply should fail due to conflict") diff --git a/environment/state.go b/environment/state.go index ed7e2d4a..a0a47152 100644 --- a/environment/state.go +++ b/environment/state.go @@ -10,9 +10,11 @@ type State struct { CreatedAt time.Time `json:"created_at,omitempty"` UpdatedAt time.Time `json:"updated_at,omitempty"` - Config *EnvironmentConfig `json:"config,omitempty"` - Container string `json:"container,omitempty"` - Title string `json:"title,omitempty"` + Config *EnvironmentConfig `json:"config,omitempty"` + Container string `json:"container,omitempty"` + Title string `json:"title,omitempty"` + TrackingBranch string `json:"tracking_branch,omitempty"` + Ephemeral bool `json:"ephemeral,omitempty"` } func (s *State) Marshal() ([]byte, error) { diff --git a/go.mod b/go.mod index 461f6838..4a42d9c9 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.3 toolchain go1.24.4 require ( - dagger.io/dagger v0.18.12 + dagger.io/dagger v0.18.14 github.com/charmbracelet/bubbletea v1.3.5 github.com/charmbracelet/fang v0.3.0 github.com/charmbracelet/lipgloss v1.1.0 @@ -14,6 +14,7 @@ require ( github.com/mark3labs/mcp-go v0.29.0 github.com/mitchellh/go-homedir v1.1.0 github.com/pelletier/go-toml/v2 v2.2.4 + github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 github.com/tiborvass/go-watch v0.0.0-20250607214558-08999a83bf8b diff --git a/go.sum b/go.sum index 44d4f368..021f5c84 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ -dagger.io/dagger v0.18.11 h1:6lSfemlbGM2HmdOjhgevrX2+orMDGKU/xTaBMZ+otyY= -dagger.io/dagger v0.18.11/go.mod h1:azlZ24m2br95t0jQHUBpL5SiafeqtVDLl1Itlq6GO+4= -dagger.io/dagger v0.18.12 h1:s7v8aHlzDUogZ/jW92lHC+gljCNRML+0mosfh13R4vs= -dagger.io/dagger v0.18.12/go.mod h1:azlZ24m2br95t0jQHUBpL5SiafeqtVDLl1Itlq6GO+4= +dagger.io/dagger v0.18.14 h1:7+VFqNJffm6Qa8ckNRMfsM64sI5dXbRnZswCQ1jnDF0= +dagger.io/dagger v0.18.14/go.mod h1:azlZ24m2br95t0jQHUBpL5SiafeqtVDLl1Itlq6GO+4= github.com/99designs/gqlgen v0.17.75 h1:GwHJsptXWLHeY7JO8b7YueUI4w9Pom6wJTICosDtQuI= github.com/99designs/gqlgen v0.17.75/go.mod h1:p7gbTpdnHyl70hmSpM8XG8GiKwmCv+T5zkdY8U8bLog= github.com/Khan/genqlient v0.8.1 h1:wtOCc8N9rNynRLXN3k3CnfzheCUNKBcvXmVv5zt6WCs= @@ -112,6 +110,8 @@ github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sosodev/duration v1.3.1 h1:qtHBDMQ6lvMQsL15g4aopM4HEfOaYuhWBw3NPTtlqq4= github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERAikUR6SDg= +github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg= +github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= diff --git a/mcpserver/tools.go b/mcpserver/tools.go index 1433ab24..8cf0d89c 100644 --- a/mcpserver/tools.go +++ b/mcpserver/tools.go @@ -137,11 +137,14 @@ func init() { EnvironmentFileReadTool, EnvironmentFileListTool, EnvironmentFileWriteTool, + EnvironmentFilePatchTool, EnvironmentFileDeleteTool, EnvironmentAddServiceTool, EnvironmentCheckpointTool, + + EnvironmentSyncFromUserTool, ) } @@ -231,6 +234,13 @@ Environment configuration is managed by the user via cu config commands.`, mcp.Description("Short description of the work that is happening in this environment."), mcp.Required(), ), + mcp.WithBoolean("ephemeral", + mcp.Description("Whether this environment is for a sub-task of a larger task."), + mcp.Required(), + ), + mcp.WithString("background_branch", + mcp.Description("A user-supplied branch name to create and track, instead of the current branch."), + ), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { repo, err := openRepository(ctx, request) @@ -241,17 +251,31 @@ Environment configuration is managed by the user via cu config commands.`, if err != nil { return nil, err } + branch := request.GetString("background_branch", "") + ephemeral := request.GetBool("ephemeral", false) + if branch == "" && !ephemeral { + branch, err = repo.CurrentUserBranch(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current branch: %w", err) + } + } dag, ok := ctx.Value(daggerClientKey{}).(*dagger.Client) if !ok { return nil, fmt.Errorf("dagger client not found in context") } - env, err := repo.Create(ctx, dag, title, request.GetString("explanation", "")) + env, err := repo.Create(ctx, dag, branch, title, request.GetString("explanation", ""), ephemeral) if err != nil { return nil, fmt.Errorf("failed to create environment: %w", err) } + if !ephemeral { + if err := repo.TrackEnvironment(ctx, branch, env.ID); err != nil { + return nil, fmt.Errorf("unable to set current branch tracking environment: %w", err) + } + } + out, err := marshalEnvironment(env) if err != nil { return nil, fmt.Errorf("failed to marshal environment: %w", err) @@ -613,6 +637,73 @@ var EnvironmentFileWriteTool = &Tool{ }, } +var EnvironmentFilePatchTool = &Tool{ + Definition: mcp.NewTool("environment_file_patch", + mcp.WithDescription("Find and replace text in a file."), + mcp.WithString("explanation", + mcp.Description("One sentence explanation for why this file is being edited."), + ), + mcp.WithString("environment_source", + mcp.Description("Absolute path to the source git repository for the environment."), + mcp.Required(), + ), + mcp.WithString("environment_id", + mcp.Description("The ID of the environment for this command. Must call `environment_create` first."), + mcp.Required(), + ), + mcp.WithString("target_file", + mcp.Description("Path of the file to write, absolute or relative to the workdir."), + mcp.Required(), + ), + mcp.WithString("search_text", + mcp.Description("The text to find and replace."), + mcp.Required(), + ), + mcp.WithString("replace_text", + mcp.Description("The text to insert."), + mcp.Required(), + ), + mcp.WithString("which_match", + mcp.Description("The ID of the match to replace, if there were multiple matches."), + ), + ), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + repo, env, err := openEnvironment(ctx, request) + if err != nil { + return mcp.NewToolResultErrorFromErr("unable to open the environment", err), nil + } + + targetFile, err := request.RequireString("target_file") + if err != nil { + return nil, err + } + search, err := request.RequireString("search_text") + if err != nil { + return nil, err + } + replace, err := request.RequireString("replace_text") + if err != nil { + return nil, err + } + + if err := env.FileSearchReplace(ctx, + request.GetString("explanation", ""), + targetFile, + search, + replace, + request.GetString("which_match", ""), + ); err != nil { + return mcp.NewToolResultErrorFromErr("failed to write file", err), nil + } + + if err := repo.Update(ctx, env, request.GetString("explanation", "")); err != nil { + return mcp.NewToolResultErrorFromErr("unable to update the environment", err), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("file %s edited successfully and committed to container-use/ remote", targetFile)), nil + }, +} + var EnvironmentFileDeleteTool = &Tool{ Definition: newEnvironmentTool( "environment_file_delete", @@ -756,3 +847,50 @@ Supported schemas are: return mcp.NewToolResultText(fmt.Sprintf("Service added and started successfully: %s", string(output))), nil }, } + +var EnvironmentSyncFromUserTool = &Tool{ + Definition: newEnvironmentTool( + "environment_sync_from_user", + "Apply the user's unstaged changes to the environment. ONLY RUN WHEN EXPLICITLY REQUESTED BY THE USER.", + ), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + repo, env, err := openEnvironment(ctx, request) + if err != nil { + return nil, err + } + + currentBranch, err := repo.CurrentUserBranch(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current branch: %w", err) + } + branchEnv, err := repo.TrackedEnvironment(ctx, currentBranch) + if err != nil { + return nil, err + } + if branchEnv != env.ID { + return nil, fmt.Errorf("branch is tracking %s, not %s", branchEnv, env.ID) + } + + patch, err := repo.DiffUserLocalChanges(ctx) + if err != nil { + return nil, fmt.Errorf("failed to generate patch: %w", err) + } + if len(patch) == 0 { + return mcp.NewToolResultText("No unstaged changes to pull."), nil + } + + if err := env.ApplyPatch(ctx, patch); err != nil { + return nil, fmt.Errorf("failed to pull changes to environment: %w", err) + } + + if err := repo.ResetUserLocalChanges(ctx); err != nil { + return nil, fmt.Errorf("unable to reset user's worktree: %w", err) + } + + if err := repo.Update(ctx, env, request.GetString("explanation", "")); err != nil { + return nil, fmt.Errorf("unable to update the environment: %w", err) + } + + return mcp.NewToolResultText("Patch applied successfully to the environment:\n\n```patch\n" + string(patch) + "\n```"), nil + }, +} diff --git a/repository/git.go b/repository/git.go index 5bae7ced..1f9695ab 100644 --- a/repository/git.go +++ b/repository/git.go @@ -282,16 +282,25 @@ func (r *Repository) addGitNote(ctx context.Context, env *environment.Environmen return r.propagateGitNotes(ctx, gitNotesLogRef) } -func (r *Repository) currentUserBranch(ctx context.Context) (string, error) { - return RunGitCommand(ctx, r.userRepoPath, "branch", "--show-current") +func (r *Repository) CurrentUserBranch(ctx context.Context) (string, error) { + currentBranch, err := RunGitCommand(ctx, r.userRepoPath, "branch", "--show-current") + if err != nil { + return "", err + } + // TODO(vito): pretty sure this is redundant, but consolidating from other + // places + branch := strings.TrimSpace(currentBranch) + if branch == "" { + return "", fmt.Errorf("no current branch (detached HEAD?)") + } + return branch, nil } func (r *Repository) mergeBase(ctx context.Context, env *environment.EnvironmentInfo) (string, error) { - currentBranch, err := r.currentUserBranch(ctx) + currentBranch, err := r.CurrentUserBranch(ctx) if err != nil { return "", err } - currentBranch = strings.TrimSpace(currentBranch) if currentBranch == "" { currentBranch = "HEAD" } diff --git a/repository/repository.go b/repository/repository.go index 81803f53..90c9bc94 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -144,7 +144,7 @@ func (r *Repository) exists(ctx context.Context, id string) error { // Create creates a new environment with the given description and explanation. // Requires a dagger client for container operations during environment initialization. -func (r *Repository) Create(ctx context.Context, dag *dagger.Client, description, explanation string) (*environment.Environment, error) { +func (r *Repository) Create(ctx context.Context, dag *dagger.Client, branch, description, explanation string, ephemeral bool) (*environment.Environment, error) { id := petname.Generate(2, "-") worktree, err := r.initializeWorktree(ctx, id) if err != nil { @@ -173,7 +173,7 @@ func (r *Repository) Create(ctx context.Context, dag *dagger.Client, description return nil, err } - env, err := environment.New(ctx, dag, id, description, config, baseSourceDir) + env, err := environment.New(ctx, dag, id, branch, description, ephemeral, config, baseSourceDir) if err != nil { return nil, err } @@ -283,6 +283,23 @@ func (r *Repository) Update(ctx context.Context, env *environment.Environment, e if err := r.propagateToWorktree(ctx, env, explanation); err != nil { return err } + + // Check if branch tracking is enabled and we're on the tracked branch + if env.State.TrackingBranch != "" { + currentBranch, err := r.CurrentUserBranch(ctx) + if err != nil { + return fmt.Errorf("failed to check current branch for tracking: %w", err) + } else { + if currentBranch == env.State.TrackingBranch { + // Apply environment changes to the user's working tree + var logs strings.Builder + if err := r.Apply(ctx, env.ID, &logs, false); err != nil { + return fmt.Errorf("failed to apply tracking changes to working tree: %w\n\nlogs:\n%s\n", err, logs.String()) + } + } + } + } + if note := env.Notes.Pop(); note != "" { return r.addGitNote(ctx, env, note) } @@ -414,11 +431,109 @@ func (r *Repository) Merge(ctx context.Context, id string, w io.Writer) error { return RunInteractiveGitCommand(ctx, r.userRepoPath, w, "merge", "--no-ff", "--autostash", "-m", "Merge environment "+envInfo.ID, "--", "container-use/"+envInfo.ID) } -func (r *Repository) Apply(ctx context.Context, id string, w io.Writer) error { +func (r *Repository) Apply(ctx context.Context, id string, w io.Writer, discard bool) (rerr error) { envInfo, err := r.Info(ctx, id) if err != nil { return err } - return RunInteractiveGitCommand(ctx, r.userRepoPath, w, "merge", "--autostash", "--squash", "--", "container-use/"+envInfo.ID) + diffOutput, err := r.DiffUserLocalChanges(ctx) + if err != nil { + return fmt.Errorf("failed to check for unstaged changes: %w", err) + } + + hasUnstagedChanges := len(diffOutput) > 0 + + fmt.Fprintf(w, "Creating virtual stash as backup...\n") + stashID, err := RunGitCommand(ctx, r.userRepoPath, "stash", "create") + if err != nil { + return fmt.Errorf("failed to stash changes: %w", err) + } + defer func() { + if rerr != nil { + fmt.Fprintf(w, "ERROR: %s\n", rerr) + fmt.Fprintf(w, "Your prior changes can be restored with `git stash apply %s`\n", stashID) + } + }() + + // Reset to clean state + if err := RunInteractiveGitCommand(ctx, r.userRepoPath, w, "reset", "--hard", "HEAD"); err != nil { + return fmt.Errorf("failed to reset: %w", err) + } + + // Apply the merge without autostash + fmt.Fprintf(w, "Applying environment changes...\n") + if err := RunInteractiveGitCommand(ctx, r.userRepoPath, w, "merge", "--squash", "--", "container-use/"+envInfo.ID); err != nil { + return fmt.Errorf("failed to merge: %w", err) + } + + // Apply user changes back + if hasUnstagedChanges && !discard { + fmt.Fprintf(w, "Restoring user changes...\n") + + // 1. Temporarily commit the agent's changes + if err := RunInteractiveGitCommand(ctx, r.userRepoPath, w, "commit", "-m", "temp: agent changes"); err != nil { + return fmt.Errorf("failed to commit agent changes: %w", err) + } + + // 2. Apply the user's patch + applyCmd := exec.CommandContext(ctx, "git", "apply", "-") + applyCmd.Dir = r.userRepoPath + applyCmd.Stdin = strings.NewReader(diffOutput) + applyCmd.Stdout = w + applyCmd.Stderr = w + if err := applyCmd.Run(); err != nil { + return fmt.Errorf("failed to apply user changes: %w", err) + } + + // 3. Reset to unstage the user's changes + if err := RunInteractiveGitCommand(ctx, r.userRepoPath, w, "reset"); err != nil { + return fmt.Errorf("failed to reset user changes: %w", err) + } + + // 4. Soft reset to bring agent changes back to staging + if err := RunInteractiveGitCommand(ctx, r.userRepoPath, w, "reset", "--soft", "HEAD~1"); err != nil { + return fmt.Errorf("failed to restore agent changes to staging: %w", err) + } + + // Clean up patch file on successful application + fmt.Fprintf(w, "User changes successfully restored as unstaged changes.\n") + } + + return nil +} + +func (r *Repository) DiffUserLocalChanges(ctx context.Context) (string, error) { + diff, err := RunGitCommand(ctx, r.userRepoPath, "diff", "--binary") + if err != nil { + return "", fmt.Errorf("failed to get user diff: %w", err) + } + return diff, nil +} + +func (r *Repository) TrackEnvironment(ctx context.Context, branch, envID string) error { + _, err := RunGitCommand(ctx, r.userRepoPath, "config", "branch."+branch+".environment", envID) + if err != nil { + return fmt.Errorf("failed to set branch tracking env: %w", err) + } + return nil +} + +func (r *Repository) TrackedEnvironment(ctx context.Context, branch string) (string, error) { + envID, err := RunGitCommand(ctx, r.userRepoPath, "config", "get", "--default=", "branch."+branch+".environment") + if err != nil { + return "", fmt.Errorf("failed to get branch tracking env: %w", err) + } + envID = strings.TrimSpace(envID) + if envID == "" { + return "", fmt.Errorf("branch %s is not tracking an environment", branch) + } + return envID, nil +} + +func (r *Repository) ResetUserLocalChanges(ctx context.Context) error { + if _, err := RunGitCommand(ctx, r.userRepoPath, "restore", "."); err != nil { + return fmt.Errorf("failed to reset unstaged changes: %w", err) + } + return nil }