Skip to content

Commit f38e55b

Browse files
committed
feat: Extract command from backticks
1 parent aac2f54 commit f38e55b

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ python -m venv env && source ./env/bin/activate
2020
Install the necessary dependencies, including development and test dependencies:
2121

2222
```shell
23-
pip install -e ."[dev,test]"
23+
pip install -e ."[dev,test,litellm]"
2424
```
2525

2626
### Start Coding
@@ -35,4 +35,4 @@ Before creating a pull request, run `scripts/lint.sh` and `scripts/tests.sh` to
3535
### Code Review
3636
After submitting your pull request, be patient and receptive to feedback from reviewers. Address any concerns they raise and collaborate to refine the code. Together, we can enhance the ShellGPT project.
3737

38-
Thank you once again for your contribution! We're excited to have you join us.
38+
Thank you once again for your contribution! We're excited to have you join us.

sgpt/handlers/handler.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from pathlib import Path
34
from typing import Any, Callable, Dict, Generator, List, Optional
45

@@ -37,6 +38,7 @@ class Handler:
3738

3839
def __init__(self, role: SystemRole, markdown: bool) -> None:
3940
self.role = role
41+
self.is_shell = role.name == DefaultRoles.SHELL.value
4042

4143
api_base_url = cfg.get("API_BASE_URL")
4244
self.base_url = None if api_base_url == "default" else api_base_url
@@ -45,6 +47,13 @@ def __init__(self, role: SystemRole, markdown: bool) -> None:
4547
self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
4648
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")
4749

50+
self.backticks_start = re.compile(r"(^|[\r\n]+)```\w*[\r\n]+")
51+
end_regex_parts = [r"[\r\n]+", "`", "`", "`", r"([\r\n]+|$)"]
52+
self.backticks_end_prefixes = [
53+
re.compile("".join(end_regex_parts[: i + 1]))
54+
for i in range(len(end_regex_parts))
55+
]
56+
4857
@property
4958
def printer(self) -> Printer:
5059
return (
@@ -82,6 +91,48 @@ def handle_function_call(
8291
yield f"```text\n{result}\n```\n"
8392
messages.append({"role": "function", "content": result, "name": name})
8493

94+
def _matches_end_at(self, text: str) -> tuple[bool, int]:
95+
end_of_match = 0
96+
for _i, regex in enumerate(self.backticks_end_prefixes):
97+
m = regex.search(text)
98+
if m:
99+
end_of_match = m.end()
100+
else:
101+
return False, end_of_match
102+
return True, m.start()
103+
104+
def _filter_chunks(
105+
self, chunks: Generator[str, None, None]
106+
) -> Generator[str, None, None]:
107+
buffer = ""
108+
inside_backticks = False
109+
end_of_beginning = 0
110+
111+
for chunk in chunks:
112+
buffer += chunk
113+
if not inside_backticks:
114+
m = self.backticks_start.search(buffer)
115+
if not m:
116+
continue
117+
new_end_of_beginning = m.end()
118+
if new_end_of_beginning > end_of_beginning:
119+
end_of_beginning = new_end_of_beginning
120+
continue
121+
inside_backticks = True
122+
buffer = buffer[end_of_beginning:]
123+
if inside_backticks:
124+
matches_end, index = self._matches_end_at(buffer)
125+
if matches_end:
126+
yield buffer[:index]
127+
return
128+
if index == len(buffer):
129+
continue
130+
else:
131+
yield buffer
132+
buffer = ""
133+
if buffer:
134+
yield buffer
135+
85136
@cache
86137
def get_completion(
87138
self,
@@ -163,4 +214,6 @@ def handle(
163214
caching=caching,
164215
**kwargs,
165216
)
217+
if self.role.name == DefaultRoles.SHELL.value:
218+
generator = self._filter_chunks(generator)
166219
return self.printer(generator, not disable_stream)

tests/test_shell.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from pathlib import Path
33
from unittest.mock import patch
44

5+
import pytest
6+
57
from sgpt.config import cfg
68
from sgpt.role import DefaultRoles, SystemRole
79

@@ -22,6 +24,54 @@ def test_shell(completion):
2224
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
2325

2426

27+
@patch("sgpt.handlers.handler.completion")
28+
@pytest.mark.parametrize(
29+
"prefix,suffix",
30+
[
31+
("", ""),
32+
("some text before\n```powershell\n", "\n```" ""),
33+
("```powershell\n", "\n```\nsome text after" ""),
34+
("some text before\n```powershell\n", "\n```\nsome text after" ""),
35+
(
36+
"some text with ``` before\n```powershell\n",
37+
"\n```\nsome text with ``` after" "",
38+
),
39+
("```powershell\n", "\n```" ""),
40+
("```\n", "\n```" ""),
41+
("```powershell\r\n", "\r\n```" ""),
42+
("```\r\n", "\r\n```" ""),
43+
("```powershell\r", "\r```" ""),
44+
("```\r", "\r```" ""),
45+
],
46+
)
47+
@pytest.mark.parametrize("group_by_size", range(10))
48+
def test_shell_no_backticks(completion, prefix: str, suffix: str, group_by_size: int):
49+
expected_output = "Get-Process | \nWhere-Object { $_.Port -eq 9000 }\r\n | Select-Object Id | Text \r\nwith '```' inside"
50+
produced_output = prefix + expected_output + suffix
51+
if group_by_size == 0:
52+
produced_tokens = list(produced_output)
53+
else:
54+
produced_tokens = [
55+
produced_output[i : i + group_by_size]
56+
for i in range(0, len(produced_output), group_by_size)
57+
]
58+
assert produced_output == "".join(produced_tokens)
59+
60+
role = SystemRole.get(DefaultRoles.SHELL.value)
61+
completion.return_value = mock_comp(produced_tokens)
62+
63+
args = {"prompt": "find pid by port 9000", "--shell": True}
64+
result = runner.invoke(app, cmd_args(**args))
65+
66+
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
67+
index = result.stdout.find(expected_output)
68+
assert index >= 0
69+
rest = result.stdout[index + len(expected_output) :].strip()
70+
assert "`" not in rest
71+
assert result.exit_code == 0
72+
assert "[E]xecute, [D]escribe, [A]bort:" == rest
73+
74+
2575
@patch("sgpt.printer.TextPrinter.live_print")
2676
@patch("sgpt.printer.MarkdownPrinter.live_print")
2777
@patch("sgpt.handlers.handler.completion")

0 commit comments

Comments
 (0)