Skip to content

Commit abdc8eb

Browse files
committed
quick chat exmaple formatting fixed, added promtp formatting using black, fixed test closure. Closes #38
1 parent 4c834f2 commit abdc8eb

File tree

5 files changed

+118
-15
lines changed

5 files changed

+118
-15
lines changed

examples/quick_chat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ def format_message_history(message_history : List[Tuple[str, str]]) -> str:
3535
@ell.lm(model="gpt-4o-mini", temperature=0.3, max_tokens=20)
3636
def chat(message_history : List[Tuple[str, str]], *, personality : str):
3737

38-
return [
39-
ell.system(f"""You are
40-
{personality}.
38+
return [
39+
ell.system(f"""Here is your description.
40+
{personality}.
4141
42-
Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""),
43-
ell.user(format_message_history(message_history)),
44-
]
42+
Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""),
43+
ell.user(format_message_history(message_history)),
44+
]
4545

4646

4747

poetry.lock

Lines changed: 85 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ requests = "^2.32.3"
3737
typing-extensions = "^4.12.2"
3838

3939

40+
black = "^24.8.0"
4041
[tool.poetry.group.dev.dependencies]
4142
pytest = "^8.3.2"
4243

src/ell/util/closure.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def xD():
4040
import importlib.util
4141
import re
4242
from collections import deque
43+
import black
4344

4445
DELIM = "$$$$$$$$$$$$$$$$$$$$$$$$$"
4546
FORBIDDEN_NAMES = ["ell", "lstr"]
@@ -94,12 +95,26 @@ def lexical_closure(
9495
CLOSURE_SOURCE[hash(func)] = dirty_src
9596

9697
dsrc = _clean_src(dirty_src_without_func)
98+
99+
# Format the sorce and dsrc soruce using Black
100+
source = _format_source(source)
101+
dsrc = _format_source(dsrc)
102+
97103
fn_hash = _generate_function_hash(source, dsrc, func.__qualname__)
98104

99105
_update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses)
100106

101107
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
102108

109+
110+
def _format_source(source: str) -> str:
111+
"""Format the source code using Black."""
112+
try:
113+
return black.format_str(source, mode=black.Mode())
114+
except:
115+
# If Black formatting fails, return the original source
116+
return source
117+
103118
def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
104119
"""Get global and free variables for a function."""
105120
globals_dict = collections.OrderedDict(dill.detect.globalvars(func))
@@ -331,7 +346,10 @@ def get_referenced_names(code: str, module_name: str):
331346

332347
def lexically_closured_source(func):
333348
_, fnclosure, uses = lexical_closure(func, initial_call=True, recursion_stack=[])
334-
return fnclosure, uses
349+
source, dsrc = fnclosure
350+
formatted_source = _format_source(source)
351+
formatted_dsrc = _format_source(dsrc)
352+
return (formatted_source, formatted_dsrc), uses
335353

336354
import ast
337355

tests/test_closure.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ def test_get_referenced_names():
7878

7979
def test_is_function_called():
8080
code = """
81-
def foo():
82-
pass
83-
84-
def bar():
85-
foo()
86-
87-
x = 1 + 2
81+
def foo():
82+
pass
83+
84+
def bar():
85+
foo()
86+
87+
x = 1 + 2
8888
"""
8989
assert is_function_called("foo", code)
9090
assert not is_function_called("bar", code)

0 commit comments

Comments
 (0)