Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def main(

while shell and interaction:
option = typer.prompt(
text="[E]xecute, [D]escribe, [A]bort",
type=Choice(("e", "d", "a", "y"), case_sensitive=False),
text="[E]xecute, [D]escribe, [M]odify, [R]un (for Modify & Execute), [A]bort",
type=Choice(("e", "d", "a", "y", "m", "r"), case_sensitive=False),
default="e" if cfg.get("DEFAULT_EXECUTE_SHELL_CMD") == "true" else "a",
show_choices=False,
show_default=False,
Expand All @@ -258,6 +258,14 @@ def main(
functions=function_schemas,
)
continue
elif option == "m" or option == "r":
# modifying full_completion allows iterative manipulation.
full_completion = get_edited_prompt(full_completion)
DefaultHandler(role_class, md).printer(full_completion, False)
if option == "r":
run_command(full_completion)
break
continue
break


Expand Down
5 changes: 4 additions & 1 deletion sgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sgpt.integration import bash_integration, zsh_integration


def get_edited_prompt() -> str:
def get_edited_prompt(prompt: str = "") -> str:
"""
Opens the user's default editor to let them
input a prompt, and returns the edited text.
Expand All @@ -21,6 +21,9 @@ def get_edited_prompt() -> str:
with NamedTemporaryFile(suffix=".txt", delete=False) as file:
# Create file and store path.
file_path = file.name
if prompt:
file.write(prompt.encode("utf-8"))
file.close()
editor = os.environ.get("EDITOR", "vim")
# This will write text to file using $EDITOR.
os.system(f"{editor} {file_path}")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_shell(completion):
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
assert result.exit_code == 0
assert "git commit" in result.stdout
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
assert "[E]xecute, [D]escribe, [M]odify, [R]un (for Modify & Execute), [A]bort:" in result.stdout


@patch("sgpt.printer.TextPrinter.live_print")
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_shell_stdin(completion):
completion.assert_called_once_with(**comp_args(role, expected_prompt))
assert result.exit_code == 0
assert "ls -l | sort" in result.stdout
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
assert "[E]xecute, [D]escribe, [M]odify, [R]un (for Modify & Execute), [A]bort:" in result.stdout


@patch("sgpt.handlers.handler.completion")
Expand Down