@@ -83,7 +83,7 @@ def lexical_closure(
83
83
already_closed .add (hash (func ))
84
84
85
85
globals_and_frees = _get_globals_and_frees (func )
86
- dependencies , imports , modules = _process_dependencies (func , globals_and_frees , already_closed , recursion_stack )
86
+ dependencies , imports , modules = _process_dependencies (func , globals_and_frees , already_closed , recursion_stack , uses )
87
87
88
88
cur_src = _build_initial_source (imports , dependencies , source )
89
89
@@ -103,7 +103,7 @@ def lexical_closure(
103
103
fn_hash = _generate_function_hash (source , dsrc , func .__qualname__ )
104
104
105
105
_update_ell_func (outer_ell_func , source , dsrc , globals_and_frees ['globals' ], globals_and_frees ['frees' ], fn_hash , uses )
106
-
106
+
107
107
return (dirty_src , (source , dsrc ), ({fn_hash } if not initial_call and hasattr (outer_ell_func , "__ell_func__" ) else uses ))
108
108
109
109
@@ -117,7 +117,7 @@ def _format_source(source: str) -> str:
117
117
118
118
def _get_globals_and_frees (func : Callable ) -> Dict [str , Dict ]:
119
119
"""Get global and free variables for a function."""
120
- globals_dict = collections .OrderedDict (dill . detect . globalvars (func ))
120
+ globals_dict = collections .OrderedDict (globalvars (func ))
121
121
frees_dict = collections .OrderedDict (dill .detect .freevars (func ))
122
122
123
123
if isinstance (func , type ):
@@ -128,60 +128,71 @@ def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
128
128
129
129
return {'globals' : globals_dict , 'frees' : frees_dict }
130
130
131
- def _process_dependencies (func , globals_and_frees , already_closed , recursion_stack ):
131
+ def _process_dependencies (func , globals_and_frees , already_closed , recursion_stack , uses ):
132
132
"""Process function dependencies."""
133
133
dependencies = []
134
134
modules = deque ()
135
135
imports = []
136
136
137
137
if isinstance (func , (types .FunctionType , types .MethodType )):
138
- _process_default_kwargs (func , dependencies , already_closed , recursion_stack )
138
+ _process_default_kwargs (func , dependencies , already_closed , recursion_stack , uses )
139
139
140
140
for var_name , var_value in {** globals_and_frees ['globals' ], ** globals_and_frees ['frees' ]}.items ():
141
- _process_variable (var_name , var_value , dependencies , modules , imports , already_closed , recursion_stack )
141
+ _process_variable (var_name , var_value , dependencies , modules , imports , already_closed , recursion_stack , uses )
142
142
143
143
return dependencies , imports , modules
144
144
145
- def _process_default_kwargs (func , dependencies , already_closed , recursion_stack ):
145
+ def _process_default_kwargs (func , dependencies , already_closed , recursion_stack , uses ):
146
146
"""Process default keyword arguments of a function."""
147
147
ps = inspect .signature (func ).parameters
148
148
default_kwargs = collections .OrderedDict ({k : v .default for k , v in ps .items () if v .default is not inspect .Parameter .empty })
149
149
for name , val in default_kwargs .items ():
150
- if name not in FORBIDDEN_NAMES :
150
+ try :
151
+ is_builtin = val .__class__ .__module__ == "builtins" or val .__class__ .__module__ == "__builtins__"
152
+ except :
153
+ is_builtin = False
154
+ if name not in FORBIDDEN_NAMES and not is_builtin :
151
155
try :
152
- dep , _ , _ = lexical_closure (type (val ), already_closed = already_closed , recursion_stack = recursion_stack .copy ())
156
+ dep , _ , _uses = lexical_closure (type (val ), already_closed = already_closed , recursion_stack = recursion_stack .copy ())
153
157
dependencies .append (dep )
158
+ uses .update (_uses )
154
159
except Exception as e :
155
160
_raise_error (f"Failed to capture the lexical closure of default parameter { name } " , e , recursion_stack )
156
161
157
- def _process_variable (var_name , var_value , dependencies , modules , imports , already_closed , recursion_stack ):
162
+ def _process_variable (var_name , var_value , dependencies , modules , imports , already_closed , recursion_stack , uses ):
158
163
"""Process a single variable."""
159
164
if isinstance (var_value , (types .FunctionType , type , types .MethodType )):
160
- _process_callable (var_name , var_value , dependencies , already_closed , recursion_stack )
165
+ _process_callable (var_name , var_value , dependencies , already_closed , recursion_stack , uses )
161
166
elif isinstance (var_value , types .ModuleType ):
162
- _process_module (var_name , var_value , modules , imports )
167
+ _process_module (var_name , var_value , modules , imports , uses )
163
168
elif isinstance (var_value , types .BuiltinFunctionType ):
164
169
imports .append (dill .source .getimport (var_value , alias = var_name ))
165
170
else :
166
- _process_other_variable (var_name , var_value , dependencies )
171
+ _process_other_variable (var_name , var_value , dependencies , uses )
167
172
168
- def _process_callable (var_name , var_value , dependencies , already_closed , recursion_stack ):
173
+ def _process_callable (var_name , var_value , dependencies , already_closed , recursion_stack , uses ):
169
174
"""Process a callable (function, method, or class)."""
170
- if var_name not in FORBIDDEN_NAMES :
175
+ try :
176
+ module_is_ell = 'ell' in inspect .getmodule (var_value ).__name__
177
+ except :
178
+ module_is_ell = False
179
+
180
+ if var_name not in FORBIDDEN_NAMES and not module_is_ell :
171
181
try :
172
- dep , _ , _ = lexical_closure (var_value , already_closed = already_closed , recursion_stack = recursion_stack .copy ())
182
+ dep , _ , _uses = lexical_closure (var_value , already_closed = already_closed , recursion_stack = recursion_stack .copy ())
173
183
dependencies .append (dep )
184
+ uses .update (_uses )
174
185
except Exception as e :
175
186
_raise_error (f"Failed to capture the lexical closure of global or free variable { var_name } " , e , recursion_stack )
176
187
177
- def _process_module (var_name , var_value , modules , imports ):
188
+ def _process_module (var_name , var_value , modules , imports , uses ):
178
189
"""Process a module."""
179
190
if should_import (var_value ):
180
191
imports .append (dill .source .getimport (var_value , alias = var_name ))
181
192
else :
182
193
modules .append ((var_name , var_value ))
183
194
184
- def _process_other_variable (var_name , var_value , dependencies ):
195
+ def _process_other_variable (var_name , var_value , dependencies , uses ):
185
196
"""Process variables that are not callables or modules."""
186
197
if isinstance (var_value , str ) and '\n ' in var_value :
187
198
dependencies .append (f"{ var_name } = '''{ var_value } '''" )
@@ -400,4 +411,72 @@ def is_function_called(func_name, source_code):
400
411
return True
401
412
402
413
# If we've gone through all the nodes and haven't found a call to the function, it's not called
403
- return False
414
+ return False
415
+
416
+ #!/usr/bin/env python
417
+ #
418
+ # Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
419
+ # Modified by: William Guss.
420
+ # Copyright (c) 2008-2016 California Institute of Technology.
421
+ # Copyright (c) 2016-2024 The Uncertainty Quantification Foundation.
422
+ # License: 3-clause BSD. The full license text is available at:
423
+ # - https://github.com/uqfoundation/dill/blob/master/LICENSE
424
+ from dill .detect import nestedglobals
425
+ import inspect
426
+
427
+ def globalvars (func , recurse = True , builtin = False ):
428
+ """get objects defined in global scope that are referred to by func
429
+
430
+ return a dict of {name:object}"""
431
+ while hasattr (func , "__ell_func__" ):
432
+ func = func .__ell_func__
433
+ if inspect .ismethod (func ): func = func .__func__
434
+ while hasattr (func , "__ell_func__" ):
435
+ func = func .__ell_func__
436
+ if inspect .isfunction (func ):
437
+ globs = vars (inspect .getmodule (sum )).copy () if builtin else {}
438
+ # get references from within closure
439
+ orig_func , func = func , set ()
440
+ for obj in orig_func .__closure__ or {}:
441
+ try :
442
+ cell_contents = obj .cell_contents
443
+ except ValueError : # cell is empty
444
+ pass
445
+ else :
446
+ _vars = globalvars (cell_contents , recurse , builtin ) or {}
447
+ func .update (_vars ) #XXX: (above) be wary of infinte recursion?
448
+ globs .update (_vars )
449
+ # get globals
450
+ globs .update (orig_func .__globals__ or {})
451
+ # get names of references
452
+ if not recurse :
453
+ func .update (orig_func .__code__ .co_names )
454
+ else :
455
+ func .update (nestedglobals (orig_func .__code__ ))
456
+ # find globals for all entries of func
457
+ for key in func .copy (): #XXX: unnecessary...?
458
+ nested_func = globs .get (key )
459
+ if nested_func is orig_func :
460
+ #func.remove(key) if key in func else None
461
+ continue #XXX: globalvars(func, False)?
462
+ func .update (globalvars (nested_func , True , builtin ))
463
+ elif inspect .iscode (func ):
464
+ globs = vars (inspect .getmodule (sum )).copy () if builtin else {}
465
+ #globs.update(globals())
466
+ if not recurse :
467
+ func = func .co_names # get names
468
+ else :
469
+ orig_func = func .co_name # to stop infinite recursion
470
+ func = set (nestedglobals (func ))
471
+ # find globals for all entries of func
472
+ for key in func .copy (): #XXX: unnecessary...?
473
+ if key is orig_func :
474
+ #func.remove(key) if key in func else None
475
+ continue #XXX: globalvars(func, False)?
476
+ nested_func = globs .get (key )
477
+ func .update (globalvars (nested_func , True , builtin ))
478
+ else :
479
+ return {}
480
+ #NOTE: if name not in __globals__, then we skip it...
481
+ return dict ((name ,globs [name ]) for name in func if name in globs )
482
+
0 commit comments