55import sys
66from dataclasses import dataclass # noqa: I001
77from itertools import product
8- from typing import AsyncGenerator
8+ from typing import AsyncGenerator , Awaitable , Callable , ParamSpec , TypeVar
99
1010import termcolor
1111
12+ P = ParamSpec ("P" )
13+ R = TypeVar ("R" )
1214from invariant .analyzer .language .ast import (
1315 ArrayLiteral ,
1416 BinaryExpr ,
@@ -786,10 +788,6 @@ def _is_unknown(self, value):
786788 async def visit_FunctionCall (self , node : FunctionCall ):
787789 function = await self .visit (node .name )
788790 args = await asyncio .gather (* [self .visit (arg ) for arg in node .args ])
789-
790- # only call functions, once all parameters are known
791- if function is Unknown or any (self ._is_unknown (arg ) for arg in args ):
792- return Unknown
793791 kwarg_items = await asyncio .gather (
794792 * [
795793 asyncio .gather (self .visit (entry .key ), self .visit (entry .value ))
@@ -798,6 +796,14 @@ async def visit_FunctionCall(self, node: FunctionCall):
798796 )
799797 kwargs = {k : v for k , v in kwarg_items }
800798
799+ # only call functions, once all parameters are known
800+ if (
801+ function is Unknown
802+ or any (self ._is_unknown (arg ) for arg in args )
803+ or any (self ._is_unknown (v ) for v in kwargs .values ())
804+ ):
805+ return Unknown
806+
801807 if isinstance (function , Declaration ):
802808 return await self .visit_PredicateCall (function , args , ** kwargs )
803809 else :
@@ -932,6 +938,9 @@ async def visit_ListComprehension(self, node: ListComprehension):
932938 if iterable is None :
933939 return []
934940
941+ if iterable is Unknown :
942+ return Unknown
943+
935944 var_name = node .var_name .id if hasattr (node .var_name , "id" ) else node .var_name
936945 results = []
937946 original_vars = self .variable_store .copy ()
@@ -953,14 +962,21 @@ async def visit_ListComprehension(self, node: ListComprehension):
953962 self .variable_store = original_vars
954963 return results
955964
956- async def acall_function (self , function , * args , ** kwargs ):
965+ async def acall_function (
966+ self ,
967+ function : Callable [P , Awaitable [R ]] | Callable [P , R ] | CachedFunctionWrapper ,
968+ * args : P .args ,
969+ ** kwargs : P .kwargs ,
970+ ) -> R :
957971 ctx : EvaluationContext = self .evaluation_context
958972 # if function is a cached function wrapper, unwrap it (at this point we
959973 # are already caching it)
960- if isinstance (function , CachedFunctionWrapper ):
961- function = function .func
962- linked_function = ctx .link (function , None )
974+ func = function .func if isinstance (function , CachedFunctionWrapper ) else function
975+ linked_function = ctx .link (func , None )
963976 # also unwrap linked function, if a cached function wrapper
964- if isinstance (linked_function , CachedFunctionWrapper ):
965- linked_function = linked_function .func
966- return await ctx .acall_function (linked_function , args , ** kwargs )
977+ linked_function = (
978+ linked_function .func
979+ if isinstance (linked_function , CachedFunctionWrapper )
980+ else linked_function
981+ )
982+ return await ctx .acall_function (linked_function , * args , ** kwargs )
0 commit comments