Skip to content

Commit 6f53284

Browse files
Add imports parameter to django_shell tool and refactor code parsing (#27)
1 parent e4c2989 commit 6f53284

File tree

7 files changed

+355
-159
lines changed

7 files changed

+355
-159
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ and this project attempts to adhere to [Semantic Versioning](https://semver.org/
1818

1919
## [Unreleased]
2020

21+
### Added
22+
23+
- Optional `imports` parameter to `django_shell` tool for providing import statements separately from main code execution.
24+
2125
## [0.7.0]
2226

2327
### Added

src/mcp_django_shell/code.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import Any
5+
from typing import Literal
6+
7+
8+
def filter_existing_imports(imports: str, globals_dict: dict[str, Any]) -> str:
9+
tree = ast.parse(imports)
10+
11+
if not all(isinstance(stmt, ast.Import | ast.ImportFrom) for stmt in tree.body):
12+
raise ValueError("Input must contain only import statements")
13+
14+
filtered_statements: list[tuple[ast.Import | ast.ImportFrom, list[ast.alias]]] = []
15+
16+
for stmt in tree.body:
17+
if isinstance(stmt, ast.Import):
18+
needed = [
19+
alias
20+
for alias in stmt.names
21+
if (alias.asname or alias.name.split(".")[0]) not in globals_dict
22+
]
23+
if needed:
24+
filtered_statements.append((stmt, needed))
25+
26+
elif isinstance(stmt, ast.ImportFrom):
27+
if stmt.names[0].name == "*":
28+
# Star imports - always include (can't easily determine what's imported)
29+
filtered_statements.append((stmt, stmt.names))
30+
else:
31+
needed = [
32+
alias
33+
for alias in stmt.names
34+
if (alias.asname or alias.name) not in globals_dict
35+
]
36+
if needed:
37+
filtered_statements.append((stmt, needed))
38+
39+
filtered_lines: list[str] = []
40+
41+
for stmt, names in filtered_statements:
42+
import_parts: list[str] = []
43+
44+
if isinstance(stmt, ast.ImportFrom):
45+
module = stmt.module or ""
46+
level = "." * stmt.level if stmt.level else ""
47+
import_parts.append(f"from {level}{module}")
48+
49+
import_parts.append("import")
50+
51+
name_parts = [
52+
f"{alias.name} as {alias.asname}" if alias.asname else alias.name
53+
for alias in names
54+
]
55+
import_parts.append(", ".join(name_parts))
56+
57+
filtered_lines.append(" ".join(import_parts))
58+
59+
return "\n".join(filtered_lines)
60+
61+
62+
def parse_code(code: str) -> tuple[str, str, Literal["expression", "statement"]]:
63+
"""Determine how code should be executed.
64+
65+
Returns:
66+
A tuple (main_code, setup_code, code_type), where:
67+
- main_code: The code to evaluate (expression) or execute (statement)
68+
- setup_code: Lines to execute before evaluating expressions (empty for statements)
69+
- code_type: "expression" or "statement"
70+
"""
71+
72+
def can_eval(code: str) -> bool:
73+
try:
74+
compile(code, "<stdin>", "eval")
75+
return True
76+
except SyntaxError:
77+
return False
78+
79+
if can_eval(code):
80+
return code, "", "expression"
81+
82+
lines = code.strip().splitlines()
83+
last_line = lines[-1] if lines else ""
84+
85+
if can_eval(last_line):
86+
setup_code = "\n".join(lines[:-1])
87+
return last_line, setup_code, "expression"
88+
89+
return code, "", "statement"

src/mcp_django_shell/server.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from fastmcp import FastMCP
99
from mcp.types import ToolAnnotations
1010

11+
from .code import filter_existing_imports
12+
from .code import parse_code
1113
from .output import DjangoShellOutput
1214
from .output import ErrorOutput
1315
from .resources import AppResource
@@ -55,8 +57,12 @@
5557
),
5658
)
5759
async def django_shell(
58-
code: Annotated[str, "Python code to be executed inside the Django shell session"],
5960
ctx: Context,
61+
code: Annotated[str, "Python code to be executed inside the Django shell session"],
62+
imports: Annotated[
63+
str | None,
64+
"Optional import statements to execute before running the main code. Should contain all necessary imports for the code to run successfully, such as 'from django.contrib.auth.models import User\\nfrom myapp.models import MyModel'",
65+
] = None,
6066
) -> DjangoShellOutput:
6167
"""Execute Python code in a stateful Django shell session.
6268
@@ -72,17 +78,28 @@ async def django_shell(
7278
`.filter()` and `.get()` rather than their async counterparts (`.afilter()`, `.aget()`).
7379
"""
7480
logger.info(
75-
"django_shell tool called - request_id: %s, client_id: %s, code: %s",
81+
"django_shell tool called - request_id: %s, client_id: %s, code: %s, imports: %s",
7682
ctx.request_id,
7783
ctx.client_id or "unknown",
7884
(code[:100] + "..." if len(code) > 100 else code).replace("\n", "\\n"),
85+
(imports[:50] + "..." if imports and len(imports) > 50 else imports or "None"),
7986
)
8087
logger.debug(
8188
"Full code for django_shell - request_id: %s: %s", ctx.request_id, code
8289
)
90+
if imports:
91+
logger.debug(
92+
"Imports for django_shell - request_id: %s: %s", ctx.request_id, imports
93+
)
94+
95+
filtered_imports = filter_existing_imports(imports, shell.globals)
96+
if filtered_imports.strip():
97+
code = f"{filtered_imports}\n{code}"
98+
99+
parsed_code, setup, code_type = parse_code(code)
83100

84101
try:
85-
result = await shell.execute(code)
102+
result = await shell.execute(parsed_code, setup, code_type)
86103
output = DjangoShellOutput.from_result(result)
87104

88105
logger.debug(

src/mcp_django_shell/shell.py

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def reset(self):
4242
self.globals = {}
4343
self.history = []
4444

45-
async def execute(self, code: str) -> Result:
45+
async def execute(
46+
self, code: str, setup: str, code_type: Literal["expression", "statement"]
47+
) -> Result:
4648
"""Execute Python code in the Django shell context (async wrapper).
4749
4850
This async wrapper enables use from FastMCP and other async contexts.
@@ -54,9 +56,11 @@ async def execute(self, code: str) -> Result:
5456
errors.
5557
"""
5658

57-
return await sync_to_async(self._execute)(code)
59+
return await sync_to_async(self._execute)(code, setup, code_type)
5860

59-
def _execute(self, code: str) -> Result:
61+
def _execute(
62+
self, code: str, setup: str, code_type: Literal["expression", "statement"]
63+
) -> Result:
6064
"""Execute Python code in the Django shell context (synchronous).
6165
6266
Attempts to evaluate code as an expression first (returning a value),
@@ -65,6 +69,7 @@ def _execute(self, code: str) -> Result:
6569
Note: This synchronous method contains the actual execution logic.
6670
Use `execute()` for async contexts or `_execute()` for sync/testing.
6771
"""
72+
6873
code_preview = (code[:100] + "..." if len(code) > 100 else code).replace(
6974
"\n", "\\n"
7075
)
@@ -75,8 +80,6 @@ def _execute(self, code: str) -> Result:
7580

7681
with redirect_stdout(stdout), redirect_stderr(stderr):
7782
try:
78-
code, setup, code_type = parse_code(code)
79-
8083
logger.debug(
8184
"Execution type: %s, has setup: %s", code_type, bool(setup)
8285
)
@@ -85,7 +88,6 @@ def _execute(self, code: str) -> Result:
8588
code[:200] + "..." if len(code) > 200 else code,
8689
)
8790

88-
# Execute setup, if any (only applicable to expressions)
8991
if setup:
9092
logger.debug(
9193
"Setup code: %s",
@@ -146,36 +148,6 @@ def save_result(self, result: Result) -> Result:
146148
return result
147149

148150

149-
def parse_code(code: str) -> tuple[str, str, Literal["expression", "statement"]]:
150-
"""Determine how code should be executed.
151-
152-
Returns:
153-
A tuple (main_code, setup_code, code_type), where:
154-
- main_code: The code to evaluate (expression) or execute (statement)
155-
- setup_code: Lines to execute before evaluating expressions (empty for statements)
156-
- code_type: "expression" or "statement"
157-
"""
158-
159-
def can_eval(code: str) -> bool:
160-
try:
161-
compile(code, "<stdin>", "eval")
162-
return True
163-
except SyntaxError:
164-
return False
165-
166-
if can_eval(code):
167-
return code, "", "expression"
168-
169-
lines = code.strip().splitlines()
170-
last_line = lines[-1] if lines else ""
171-
172-
if can_eval(last_line):
173-
setup_code = "\n".join(lines[:-1])
174-
return last_line, setup_code, "expression"
175-
176-
return code, "", "statement"
177-
178-
179151
@dataclass
180152
class ExpressionResult:
181153
code: str

tests/test_code.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from mcp_django_shell.code import filter_existing_imports
6+
from mcp_django_shell.code import parse_code
7+
8+
9+
def test_parse_code_single_expression():
10+
code, setup, code_type = parse_code("2 + 2")
11+
12+
assert code == "2 + 2"
13+
assert setup == ""
14+
assert code_type == "expression"
15+
16+
17+
def test_parse_code_single_statement():
18+
code, setup, code_type = parse_code("x = 5")
19+
20+
assert code == "x = 5"
21+
assert setup == ""
22+
assert code_type == "statement"
23+
24+
25+
def test_parse_code_multiline_with_expression_basic():
26+
code, setup, code_type = parse_code("x = 5\ny = 10\nx + y")
27+
28+
assert code == "x + y"
29+
assert setup == "x = 5\ny = 10"
30+
assert code_type == "expression"
31+
32+
33+
def test_parse_code_multiline_statement_only():
34+
code, setup, code_type = parse_code("x = 5\ny = 10\nz = x + y")
35+
36+
assert code == "x = 5\ny = 10\nz = x + y"
37+
assert setup == ""
38+
assert code_type == "statement"
39+
40+
41+
def test_parse_code_empty_code():
42+
code, setup, code_type = parse_code("")
43+
44+
assert code == ""
45+
assert setup == ""
46+
assert code_type == "statement"
47+
48+
49+
def test_parse_code_whitespace_only():
50+
code, setup, code_type = parse_code(" \n \t ")
51+
52+
assert code == " \n \t "
53+
assert setup == ""
54+
assert code_type == "statement"
55+
56+
57+
def test_parse_code_trailing_newlines_expression():
58+
code = """\
59+
x = 5
60+
y = 10
61+
x + y
62+
63+
64+
"""
65+
code, setup, code_type = parse_code(code)
66+
67+
assert code == "x + y"
68+
# strip() removes leading/trailing empty lines
69+
assert setup == "x = 5\ny = 10"
70+
assert code_type == "expression"
71+
72+
73+
def test_parse_code_trailing_whitespace_expression():
74+
code, setup, code_type = parse_code("2 + 2 \n\n ")
75+
76+
# strip() removes trailing whitespace
77+
assert code == "2 + 2"
78+
assert setup == ""
79+
assert code_type == "expression"
80+
81+
82+
def test_parse_code_leading_newlines_expression():
83+
code, setup, code_type = parse_code("\n\n\n2 + 2")
84+
85+
# Single expressions are returned as-is, not stripped
86+
assert code == "\n\n\n2 + 2"
87+
assert setup == ""
88+
assert code_type == "expression"
89+
90+
91+
def test_parse_code_multiline_trailing_newlines():
92+
code, setup, code_type = parse_code("x = 5\nx + 10\n\n")
93+
94+
assert code == "x + 10"
95+
assert setup == "x = 5"
96+
assert code_type == "expression"
97+
98+
99+
def test_parse_code_empty_list():
100+
code, setup, code_type = parse_code("[]")
101+
102+
assert code == "[]"
103+
assert setup == ""
104+
assert code_type == "expression"
105+
106+
107+
def test_filter_existing_imports_star_import():
108+
result = filter_existing_imports("from os import *", {"os": True})
109+
110+
assert result == "from os import *"
111+
112+
113+
def test_filter_existing_imports_relative():
114+
result = filter_existing_imports(
115+
"from ..models import User\nfrom ...core import Base", {}
116+
)
117+
118+
assert result == "from ..models import User\nfrom ...core import Base"
119+
120+
121+
def test_filter_existing_imports_invalid_raises_error():
122+
with pytest.raises(ValueError, match="Input must contain only import statements"):
123+
filter_existing_imports("x = 5", {})

0 commit comments

Comments
 (0)