Skip to content

Commit 4874f6a

Browse files
Extract code parsing to separate method (#4)
1 parent ac0cc8a commit 4874f6a

File tree

2 files changed

+236
-187
lines changed

2 files changed

+236
-187
lines changed

src/mcp_django_shell/shell.py

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datetime import datetime
99
from io import StringIO
1010
from typing import Any
11+
from typing import Literal
1112

1213
from asgiref.sync import sync_to_async
1314

@@ -50,56 +51,37 @@ def _execute(self, code: str, timeout: int | None = None) -> Result:
5051
Note: This synchronous method contains the actual execution logic.
5152
Use `execute()` for async contexts or `_execute()` for sync/testing.
5253
"""
53-
54-
def can_eval(code: str) -> bool:
55-
try:
56-
compile(code, "<stdin>", "eval")
57-
return True
58-
except SyntaxError:
59-
return False
60-
6154
stdout = StringIO()
6255
stderr = StringIO()
6356

6457
with redirect_stdout(stdout), redirect_stderr(stderr):
6558
try:
66-
# Try as single expression
67-
if can_eval(code):
68-
payload = eval(code, self.globals)
69-
return self.save_result(
70-
ExpressionResult(
71-
code=code,
72-
value=payload,
73-
stdout=stdout.getvalue(),
74-
stderr=stderr.getvalue(),
59+
code, setup, code_type = parse_code(code)
60+
61+
# Execute setup, if any (only applicable to expressions)
62+
if setup:
63+
exec(setup, self.globals)
64+
65+
match code_type:
66+
case "expression":
67+
value = eval(code, self.globals)
68+
return self.save_result(
69+
ExpressionResult(
70+
code=code,
71+
value=value,
72+
stdout=stdout.getvalue(),
73+
stderr=stderr.getvalue(),
74+
)
7575
)
76-
)
77-
78-
# Check for multi-line with final expression
79-
lines = code.strip().splitlines()
80-
last_line = lines[-1] if lines else ""
81-
82-
if can_eval(last_line):
83-
# Execute setup lines, eval last line
84-
if len(lines) > 1:
85-
exec("\n".join(lines[:-1]), self.globals)
86-
payload = eval(last_line, self.globals)
87-
return self.save_result(
88-
ExpressionResult(
89-
code=code,
90-
value=payload,
91-
stdout=stdout.getvalue(),
92-
stderr=stderr.getvalue(),
76+
case "statement":
77+
exec(code, self.globals)
78+
return self.save_result(
79+
StatementResult(
80+
code=code,
81+
stdout=stdout.getvalue(),
82+
stderr=stderr.getvalue(),
83+
)
9384
)
94-
)
95-
96-
# Execute as pure statements
97-
exec(code, self.globals)
98-
return self.save_result(
99-
StatementResult(
100-
code=code, stdout=stdout.getvalue(), stderr=stderr.getvalue()
101-
)
102-
)
10385

10486
except Exception as e:
10587
return self.save_result(
@@ -116,6 +98,36 @@ def save_result(self, result: Result) -> Result:
11698
return result
11799

118100

101+
def parse_code(code: str) -> tuple[str, str, Literal["expression", "statement"]]:
102+
"""Determine how code should be executed.
103+
104+
Returns:
105+
A tuple (main_code, setup_code, code_type), where:
106+
- main_code: The code to evaluate (expression) or execute (statement)
107+
- setup_code: Lines to execute before evaluating expressions (empty for statements)
108+
- code_type: "expression" or "statement"
109+
"""
110+
111+
def can_eval(code: str) -> bool:
112+
try:
113+
compile(code, "<stdin>", "eval")
114+
return True
115+
except SyntaxError:
116+
return False
117+
118+
if can_eval(code):
119+
return code, "", "expression"
120+
121+
lines = code.strip().splitlines()
122+
last_line = lines[-1] if lines else ""
123+
124+
if can_eval(last_line):
125+
setup_code = "\n".join(lines[:-1])
126+
return last_line, setup_code, "expression"
127+
128+
return code, "", "statement"
129+
130+
119131
@dataclass
120132
class ExpressionResult:
121133
code: str

0 commit comments

Comments
 (0)