1
+ import pytest
2
+ import math
3
+ from typing import Set , Any
4
+ import numpy as np
5
+ from ell .util .closure import (
6
+ lexical_closure ,
7
+ is_immutable_variable ,
8
+ should_import ,
9
+ get_referenced_names ,
10
+ is_function_called ,
11
+ )
12
+
13
+ def test_lexical_closure_simple_function ():
14
+ def simple_func (x ):
15
+ return x * 2
16
+
17
+ result , (source , dsrc ), uses = lexical_closure (simple_func )
18
+ assert "def simple_func(x):" in result
19
+ assert "return x * 2" in result
20
+ assert isinstance (uses , Set )
21
+
22
+ def test_lexical_closure_with_global ():
23
+ global_var = 10
24
+ def func_with_global ():
25
+ return global_var
26
+
27
+ result , _ , _ = lexical_closure (func_with_global )
28
+ assert "global_var = 10" in result
29
+ assert "def func_with_global():" in result
30
+
31
+ def test_lexical_closure_with_nested_function ():
32
+ def outer ():
33
+ def inner ():
34
+ return 42
35
+ return inner ()
36
+
37
+ result , _ , _ = lexical_closure (outer )
38
+ assert "def outer():" in result
39
+ assert "def inner():" in result
40
+ assert "return 42" in result
41
+
42
+ def test_lexical_closure_with_default_args ():
43
+ def func_with_default (x = 10 ):
44
+ return x
45
+
46
+ result , _ , _ = lexical_closure (func_with_default )
47
+ print (result )
48
+ assert "def func_with_default(x=10):" in result
49
+
50
+ @pytest .mark .parametrize ("value, expected" , [
51
+ (42 , True ),
52
+ ("string" , True ),
53
+ ((1 , 2 , 3 ), True ),
54
+ ([1 , 2 , 3 ], False ),
55
+ ({"a" : 1 }, False ),
56
+ ])
57
+ def test_is_immutable_variable (value , expected ):
58
+ assert is_immutable_variable (value ) == expected
59
+
60
+ def test_should_import ():
61
+ import os
62
+ assert should_import (os )
63
+
64
+ class DummyModule :
65
+ __name__ = "dummy"
66
+ dummy = DummyModule ()
67
+ assert not should_import (dummy )
68
+
69
+ def test_get_referenced_names ():
70
+ code = """
71
+ import math
72
+ result = math.sin(x) + math.cos(y)
73
+ """
74
+ referenced = get_referenced_names (code , "math" )
75
+ print (referenced )
76
+ assert "sin" in referenced
77
+ assert "cos" in referenced
78
+
79
+ def test_is_function_called ():
80
+ code = """
81
+ def foo():
82
+ pass
83
+
84
+ def bar():
85
+ foo()
86
+
87
+ x = 1 + 2
88
+ """
89
+ assert is_function_called ("foo" , code )
90
+ assert not is_function_called ("bar" , code )
91
+ assert not is_function_called ("nonexistent" , code )
92
+
93
+ # Addressing linter errors
94
+ def test_lexical_closure_signature ():
95
+ def dummy_func ():
96
+ pass
97
+
98
+ # Test that the function accepts None for these arguments
99
+ result , _ , _ = lexical_closure (dummy_func , already_closed = None , recursion_stack = None )
100
+ assert result # Just check that it doesn't raise an exception
101
+
102
+ def test_lexical_closure_uses_type ():
103
+ def dummy_func ():
104
+ pass
105
+
106
+ _ , _ , uses = lexical_closure (dummy_func , initial_call = True )
107
+ assert isinstance (uses , Set )
108
+ # You might want to add a more specific check for the content of 'uses'
0 commit comments