Skip to content

Commit 6156aec

Browse files
committed
fix bug with dependencies?
1 parent 16cccf2 commit 6156aec

File tree

4 files changed

+231
-121
lines changed

4 files changed

+231
-121
lines changed

src/ell/decorators/lm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def decorator(
2424
fn: LMP,
2525
) -> InvocableLM:
2626
color = compute_color(fn)
27+
_under_fn = fn
2728

2829
@wraps(fn)
2930
def wrapper(
@@ -48,7 +49,7 @@ def wrapper(
4849

4950
# TODO: # we'll deal with type safety here later
5051
wrapper.__ell_lm_kwargs__ = lm_kwargs
51-
wrapper.__ell_func__ = fn
52+
wrapper.__ell_func__ = _under_fn
5253
wrapper.__ell_lm = True
5354
wrapper.__ell_exempt_from_tracking = exempt_from_tracking
5455
if exempt_from_tracking:

src/ell/util/closure.py

Lines changed: 98 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def lexical_closure(
8383
already_closed.add(hash(func))
8484

8585
globals_and_frees = _get_globals_and_frees(func)
86-
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack)
86+
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack, uses)
8787

8888
cur_src = _build_initial_source(imports, dependencies, source)
8989

@@ -103,7 +103,7 @@ def lexical_closure(
103103
fn_hash = _generate_function_hash(source, dsrc, func.__qualname__)
104104

105105
_update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses)
106-
106+
107107
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
108108

109109

@@ -117,7 +117,7 @@ def _format_source(source: str) -> str:
117117

118118
def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
119119
"""Get global and free variables for a function."""
120-
globals_dict = collections.OrderedDict(dill.detect.globalvars(func))
120+
globals_dict = collections.OrderedDict(globalvars(func))
121121
frees_dict = collections.OrderedDict(dill.detect.freevars(func))
122122

123123
if isinstance(func, type):
@@ -128,60 +128,71 @@ def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
128128

129129
return {'globals': globals_dict, 'frees': frees_dict}
130130

131-
def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack):
131+
def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack, uses):
132132
"""Process function dependencies."""
133133
dependencies = []
134134
modules = deque()
135135
imports = []
136136

137137
if isinstance(func, (types.FunctionType, types.MethodType)):
138-
_process_default_kwargs(func, dependencies, already_closed, recursion_stack)
138+
_process_default_kwargs(func, dependencies, already_closed, recursion_stack, uses)
139139

140140
for var_name, var_value in {**globals_and_frees['globals'], **globals_and_frees['frees']}.items():
141-
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack)
141+
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack, uses)
142142

143143
return dependencies, imports, modules
144144

145-
def _process_default_kwargs(func, dependencies, already_closed, recursion_stack):
145+
def _process_default_kwargs(func, dependencies, already_closed, recursion_stack, uses):
146146
"""Process default keyword arguments of a function."""
147147
ps = inspect.signature(func).parameters
148148
default_kwargs = collections.OrderedDict({k: v.default for k, v in ps.items() if v.default is not inspect.Parameter.empty})
149149
for name, val in default_kwargs.items():
150-
if name not in FORBIDDEN_NAMES:
150+
try:
151+
is_builtin = val.__class__.__module__ == "builtins" or val.__class__.__module__ == "__builtins__"
152+
except:
153+
is_builtin = False
154+
if name not in FORBIDDEN_NAMES and not is_builtin:
151155
try:
152-
dep, _, _ = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
156+
dep, _, _uses = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
153157
dependencies.append(dep)
158+
uses.update(_uses)
154159
except Exception as e:
155160
_raise_error(f"Failed to capture the lexical closure of default parameter {name}", e, recursion_stack)
156161

157-
def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack):
162+
def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack , uses):
158163
"""Process a single variable."""
159164
if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
160-
_process_callable(var_name, var_value, dependencies, already_closed, recursion_stack)
165+
_process_callable(var_name, var_value, dependencies, already_closed, recursion_stack, uses)
161166
elif isinstance(var_value, types.ModuleType):
162-
_process_module(var_name, var_value, modules, imports)
167+
_process_module(var_name, var_value, modules, imports, uses)
163168
elif isinstance(var_value, types.BuiltinFunctionType):
164169
imports.append(dill.source.getimport(var_value, alias=var_name))
165170
else:
166-
_process_other_variable(var_name, var_value, dependencies)
171+
_process_other_variable(var_name, var_value, dependencies, uses)
167172

168-
def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack):
173+
def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack, uses):
169174
"""Process a callable (function, method, or class)."""
170-
if var_name not in FORBIDDEN_NAMES:
175+
try:
176+
module_is_ell = 'ell' in inspect.getmodule(var_value).__name__
177+
except:
178+
module_is_ell = False
179+
180+
if var_name not in FORBIDDEN_NAMES and not module_is_ell:
171181
try:
172-
dep, _, _ = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
182+
dep, _, _uses = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
173183
dependencies.append(dep)
184+
uses.update(_uses)
174185
except Exception as e:
175186
_raise_error(f"Failed to capture the lexical closure of global or free variable {var_name}", e, recursion_stack)
176187

177-
def _process_module(var_name, var_value, modules, imports):
188+
def _process_module(var_name, var_value, modules, imports, uses):
178189
"""Process a module."""
179190
if should_import(var_value):
180191
imports.append(dill.source.getimport(var_value, alias=var_name))
181192
else:
182193
modules.append((var_name, var_value))
183194

184-
def _process_other_variable(var_name, var_value, dependencies):
195+
def _process_other_variable(var_name, var_value, dependencies, uses):
185196
"""Process variables that are not callables or modules."""
186197
if isinstance(var_value, str) and '\n' in var_value:
187198
dependencies.append(f"{var_name} = '''{var_value}'''")
@@ -400,4 +411,72 @@ def is_function_called(func_name, source_code):
400411
return True
401412

402413
# If we've gone through all the nodes and haven't found a call to the function, it's not called
403-
return False
414+
return False
415+
416+
#!/usr/bin/env python
417+
#
418+
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
419+
# Modified by: William Guss.
420+
# Copyright (c) 2008-2016 California Institute of Technology.
421+
# Copyright (c) 2016-2024 The Uncertainty Quantification Foundation.
422+
# License: 3-clause BSD. The full license text is available at:
423+
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
424+
from dill.detect import nestedglobals
425+
import inspect
426+
427+
def globalvars(func, recurse=True, builtin=False):
428+
"""get objects defined in global scope that are referred to by func
429+
430+
return a dict of {name:object}"""
431+
while hasattr(func, "__ell_func__"):
432+
func = func.__ell_func__
433+
if inspect.ismethod(func): func = func.__func__
434+
while hasattr(func, "__ell_func__"):
435+
func = func.__ell_func__
436+
if inspect.isfunction(func):
437+
globs = vars(inspect.getmodule(sum)).copy() if builtin else {}
438+
# get references from within closure
439+
orig_func, func = func, set()
440+
for obj in orig_func.__closure__ or {}:
441+
try:
442+
cell_contents = obj.cell_contents
443+
except ValueError: # cell is empty
444+
pass
445+
else:
446+
_vars = globalvars(cell_contents, recurse, builtin) or {}
447+
func.update(_vars) #XXX: (above) be wary of infinte recursion?
448+
globs.update(_vars)
449+
# get globals
450+
globs.update(orig_func.__globals__ or {})
451+
# get names of references
452+
if not recurse:
453+
func.update(orig_func.__code__.co_names)
454+
else:
455+
func.update(nestedglobals(orig_func.__code__))
456+
# find globals for all entries of func
457+
for key in func.copy(): #XXX: unnecessary...?
458+
nested_func = globs.get(key)
459+
if nested_func is orig_func:
460+
#func.remove(key) if key in func else None
461+
continue #XXX: globalvars(func, False)?
462+
func.update(globalvars(nested_func, True, builtin))
463+
elif inspect.iscode(func):
464+
globs = vars(inspect.getmodule(sum)).copy() if builtin else {}
465+
#globs.update(globals())
466+
if not recurse:
467+
func = func.co_names # get names
468+
else:
469+
orig_func = func.co_name # to stop infinite recursion
470+
func = set(nestedglobals(func))
471+
# find globals for all entries of func
472+
for key in func.copy(): #XXX: unnecessary...?
473+
if key is orig_func:
474+
#func.remove(key) if key in func else None
475+
continue #XXX: globalvars(func, False)?
476+
nested_func = globs.get(key)
477+
func.update(globalvars(nested_func, True, builtin))
478+
else:
479+
return {}
480+
#NOTE: if name not in __globals__, then we skip it...
481+
return dict((name,globs[name]) for name in func if name in globs)
482+

tests/test_closure.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import wraps
12
import pytest
23
import math
34
from typing import Set, Any
@@ -9,6 +10,8 @@
910
get_referenced_names,
1011
is_function_called,
1112
)
13+
import ell
14+
1215

1316
def test_lexical_closure_simple_function():
1417
def simple_func(x):
@@ -45,6 +48,7 @@ def func_with_default(x=10):
4548

4649
result, _, _ = lexical_closure(func_with_default)
4750
print(result)
51+
4852
assert "def func_with_default(x=10):" in result
4953

5054
@pytest.mark.parametrize("value, expected", [
@@ -68,8 +72,8 @@ class DummyModule:
6872

6973
def test_get_referenced_names():
7074
code = """
71-
import math
72-
result = math.sin(x) + math.cos(y)
75+
import math
76+
result = math.sin(x) + math.cos(y)
7377
"""
7478
referenced = get_referenced_names(code, "math")
7579
print(referenced)
@@ -105,4 +109,31 @@ def dummy_func():
105109

106110
_, _, uses = lexical_closure(dummy_func, initial_call=True)
107111
assert isinstance(uses, Set)
108-
# You might want to add a more specific check for the content of 'uses'
112+
# You might want to add a more specific check for the content of 'uses'
113+
114+
115+
def test_lexical_closure_uses():
116+
117+
@ell.lm(model="gpt-4")
118+
def dependency_func():
119+
return "42"
120+
121+
122+
@ell.lm(model="gpt-4")
123+
def main_func():
124+
return dependency_func()
125+
126+
127+
# Check that uses is a set
128+
assert isinstance(main_func.__ell_uses__, set)
129+
130+
# Check that the set contains exactly one item
131+
assert dependency_func.__ell_hash__ in main_func.__ell_uses__
132+
assert len(main_func.__ell_uses__) == 1
133+
# Check that the item in the set starts with 'lmp-'
134+
assert list(main_func.__ell_uses__)[0].startswith('lmp-')
135+
assert len(dependency_func.__ell_uses__) == 0
136+
137+
138+
if __name__ == "__main__":
139+
test_lexical_closure_uses()

0 commit comments

Comments
 (0)