diff --git a/typed_python/SerializationContext.py b/typed_python/SerializationContext.py index b110be90..832bbb41 100644 --- a/typed_python/SerializationContext.py +++ b/typed_python/SerializationContext.py @@ -31,11 +31,31 @@ import types import traceback import logging +import numpy +import pickle _badModuleCache = set() +def pickledByStr(module_name: str, name: str) -> object: + """Generate the object given the module_name and name. + + This mimics pickle's behavior when given a string from __reduce__. The + string is interpreted as the name of a global variable, and pickle.whichmodule + is used to search the module namespace, generating module_name. + + Note that 'name' might contain '.' inside of it, since it's a 'local name'. + """ + module = importlib.import_module(module_name) + + instance = module + for subName in name.split('.'): + instance = getattr(instance, subName) + + return instance + + def createFunctionWithLocalsAndGlobals(code, globals): if globals is None: globals = {} @@ -708,26 +728,30 @@ def walkCodeObject(code): return (createFunctionWithLocalsAndGlobals, args, representation) if not isinstance(inst, type) and hasattr(type(inst), '__reduce_ex__'): - res = inst.__reduce_ex__(4) + if isinstance(inst, numpy.ufunc): + res = inst.__name__ + else: + res = inst.__reduce_ex__(4) - # pickle supports a protocol where __reduce__ can return a string - # giving a global name. We'll already find that separately, so we - # don't want to handle it here. We ought to look at this in more detail - # however + # mimic pickle's behaviour when a string is received. 
if isinstance(res, str): - return None + name_tuple = (inst, res) + module_name = pickle.whichmodule(*name_tuple) + res = (pickledByStr, (module_name, res,), pickledByStr) return res if not isinstance(inst, type) and hasattr(type(inst), '__reduce__'): - res = inst.__reduce__() + if isinstance(inst, numpy.ufunc): + res = inst.__name__ + else: + res = inst.__reduce__() - # pickle supports a protocol where __reduce__ can return a string - # giving a global name. We'll already find that separately, so we - # don't want to handle it here. We ought to look at this in more detail - # however + # mimic pickle's behaviour when a string is received. if isinstance(res, str): - return None + name_tuple = (inst, res) + module_name = pickle.whichmodule(*name_tuple) + res = (pickledByStr, (module_name, res,), pickledByStr) return res @@ -736,6 +760,9 @@ def walkCodeObject(code): def setInstanceStateFromRepresentation( self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None ): + if representation is pickledByStr: + return + if representation is reconstructTypeFunctionType: return diff --git a/typed_python/__init__.py b/typed_python/__init__.py index 6c4d7070..ca849577 100644 --- a/typed_python/__init__.py +++ b/typed_python/__init__.py @@ -79,8 +79,6 @@ from typed_python.lib.map import map # noqa from typed_python.lib.pmap import pmap # noqa from typed_python.lib.reduce import reduce # noqa -from typed_python.lib.timestamp import Timestamp # noqa -from typed_python.lib.datetime.date_time import UTC, NYC, TimeOfDay, DateTime, Date, PytzTimezone # noqa _types.initializeGlobalStatics() diff --git a/typed_python/compiler/binary_shared_object.py b/typed_python/compiler/binary_shared_object.py index 90089215..5c2e5765 100644 --- a/typed_python/compiler/binary_shared_object.py +++ b/typed_python/compiler/binary_shared_object.py @@ -26,8 +26,8 @@ class LoadedBinarySharedObject(LoadedModule): - def __init__(self, binarySharedObject, diskPath, functionPointers, 
globalVariableDefinitions): - super().__init__(functionPointers, globalVariableDefinitions) + def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions): + super().__init__(functionPointers, serializedGlobalVariableDefinitions) self.binarySharedObject = binarySharedObject self.diskPath = diskPath @@ -36,15 +36,17 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariabl class BinarySharedObject: """Models a shared object library (.so) loadable on linux systems.""" - def __init__(self, binaryForm, functionTypes, globalVariableDefinitions): + def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies): """ Args: - binaryForm - a bytes object containing the actual compiled code for the module - globalVariableDefinitions - a map from name to GlobalVariableDefinition + binaryForm: a bytes object containing the actual compiled code for the module + serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition + globalDependencies: a dict from function linkname to the list of global variables it depends on """ self.binaryForm = binaryForm self.functionTypes = functionTypes - self.globalVariableDefinitions = globalVariableDefinitions + self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions + self.globalDependencies = globalDependencies self.hash = sha_hash(binaryForm) @property @@ -52,14 +54,14 @@ def definedSymbols(self): return self.functionTypes.keys() @staticmethod - def fromDisk(path, globalVariableDefinitions, functionNameToType): + def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): with open(path, "rb") as f: binaryForm = f.read() - return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions) + return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) @staticmethod - def 
fromModule(module, globalVariableDefinitions, functionNameToType): + def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): target_triple = llvm.get_process_triple() target = llvm.Target.from_triple(target_triple) target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default') @@ -80,7 +82,7 @@ def fromModule(module, globalVariableDefinitions, functionNameToType): ) with open(os.path.join(tf, "module.so"), "rb") as so_file: - return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions) + return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) def load(self, storageDir): """Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer""" @@ -127,8 +129,7 @@ def loadFromPath(self, modulePath): self, modulePath, functionPointers, - self.globalVariableDefinitions + self.serializedGlobalVariableDefinitions ) - loadedModule.linkGlobalVariables() return loadedModule diff --git a/typed_python/compiler/compiler_cache.py b/typed_python/compiler/compiler_cache.py index a093fc70..26f7e2a2 100644 --- a/typed_python/compiler/compiler_cache.py +++ b/typed_python/compiler/compiler_cache.py @@ -15,9 +15,12 @@ import os import uuid import shutil -from typed_python.compiler.loaded_module import LoadedModule -from typed_python.compiler.binary_shared_object import BinarySharedObject +from typing import Optional, List + +from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject +from typed_python.compiler.directed_graph import DirectedGraph +from typed_python.compiler.typed_call_target import TypedCallTarget from typed_python.SerializationContext import SerializationContext from typed_python import Dict, ListOf @@ -52,148 +55,173 @@ def __init__(self, cacheDir): ensureDirExists(cacheDir) - self.loadedModules = Dict(str, LoadedModule)() + 
self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)() self.nameToModuleHash = Dict(str, str)() - self.modulesMarkedValid = set() - self.modulesMarkedInvalid = set() + self.moduleManifestsLoaded = set() for moduleHash in os.listdir(self.cacheDir): if len(moduleHash) == 40: self.loadNameManifestFromStoredModuleByHash(moduleHash) - def hasSymbol(self, linkName): - return linkName in self.nameToModuleHash + # the set of functions with an associated module in loadedBinarySharedObjects + self.targetsLoaded: Dict[str, TypedCallTarget] = {} - def markModuleHashInvalid(self, hashstr): - with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"): - pass + # the set of functions with linked and validated globals (i.e. ready to be run). + self.targetsValidated = set() - def loadForSymbol(self, linkName): - moduleHash = self.nameToModuleHash[linkName] + self.function_dependency_graph = DirectedGraph() + # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions) + self.global_dependencies = Dict(str, ListOf(str))() + + def hasSymbol(self, linkName: str) -> bool: + """NB this will return True even if the linkName is ultimately unretrievable.""" + return linkName in self.nameToModuleHash - nameToTypedCallTarget = {} - nameToNativeFunctionType = {} + def getTarget(self, linkName: str) -> TypedCallTarget: + if not self.hasSymbol(linkName): + raise ValueError(f'symbol not found for linkName {linkName}') + self.loadForSymbol(linkName) + return self.targetsLoaded[linkName] - if not self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): - return None + def dependencies(self, linkName: str) -> Optional[List[str]]: + """Returns all the function names that `linkName` depends on""" + return list(self.function_dependency_graph.outgoing(linkName)) - return nameToTypedCallTarget, nameToNativeFunctionType + def loadForSymbol(self, linkName: str) -> None: + """Loads the whole module, and any 
submodules, into LoadedBinarySharedObjects""" + moduleHash = self.nameToModuleHash[linkName] - def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): + self.loadModuleByHash(moduleHash) + + if linkName not in self.targetsValidated: + dependantFuncs = self.dependencies(linkName) + [linkName] + globalsToLink = {} # dict from modulehash to list of globals. + for funcName in dependantFuncs: + if funcName not in self.targetsValidated: + funcModuleHash = self.nameToModuleHash[funcName] + # append to the list of globals to link for a given module. TODO: optimise this, don't double-link. + globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, []) + + for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too. + if globs: + definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x] + for x in globs + } + self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink) + if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink): + raise RuntimeError('failed to validate globals when loading:', linkName) + + self.targetsValidated.update(dependantFuncs) + + def loadModuleByHash(self, moduleHash: str) -> None: """Load a module by name. As we load, place all the newly imported typed call targets into 'nameToTypedCallTarget' so that the rest of the system knows what functions have been uncovered. 
""" - if moduleHash in self.loadedModules: - return True + if moduleHash in self.loadedBinarySharedObjects: + return targetDir = os.path.join(self.cacheDir, moduleHash) - try: - with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: - callTargets = SerializationContext().deserialize(f.read()) + # TODO (Will) - store these names as module consts, use one .dat only + with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: + callTargets = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: - globalVarDefs = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: + serializedGlobalVarDefs = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: - functionNameToNativeType = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: + functionNameToNativeType = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: - submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - except Exception: - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: + submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - if not LoadedModule.validateGlobalVariables(globalVarDefs): - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f: + dependency_edgelist = SerializationContext().deserialize(f.read()) + + with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f: + globalDependencies = SerializationContext().deserialize(f.read()) # load the submodules first for submodule in submodules: - if not self.loadModuleByHash( - submodule, - nameToTypedCallTarget, - 
nameToNativeFunctionType - ): - return False + self.loadModuleByHash(submodule) modulePath = os.path.join(targetDir, "module.so") loaded = BinarySharedObject.fromDisk( modulePath, - globalVarDefs, - functionNameToNativeType + serializedGlobalVarDefs, + functionNameToNativeType, + globalDependencies + ).loadFromPath(modulePath) - self.loadedModules[moduleHash] = loaded + self.loadedBinarySharedObjects[moduleHash] = loaded + + self.targetsLoaded.update(callTargets) - nameToTypedCallTarget.update(callTargets) - nameToNativeFunctionType.update(functionNameToNativeType) + assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision. + self.global_dependencies.update(globalDependencies) - return True + # update the cache's dependency graph with our new edges. + for function_name, dependant_function_name in dependency_edgelist: + self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name) - def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies): + def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist): """Add new code to the compiler cache. Args: - binarySharedObject - a BinarySharedObject containing the actual assembler - we've compiled - nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us - the formal python types for all the objects - linkDependencies - a set of linknames we depend on directly. + binarySharedObject: a BinarySharedObject containing the actual assembler + we've compiled. + nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us + the formal python types for all the objects. + linkDependencies: a set of linknames we depend on directly. + dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the + module. 
""" dependentHashes = set() for name in linkDependencies: dependentHashes.add(self.nameToModuleHash[name]) - path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes) + path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist) - self.loadedModules[hashToUse] = ( + self.loadedBinarySharedObjects[hashToUse] = ( binarySharedObject.loadFromPath(os.path.join(path, "module.so")) ) for n in binarySharedObject.definedSymbols: self.nameToModuleHash[n] = hashToUse - def loadNameManifestFromStoredModuleByHash(self, moduleHash): - if moduleHash in self.modulesMarkedValid: - return True - - targetDir = os.path.join(self.cacheDir, moduleHash) + # link & validate all globals for the new module + self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables() + if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables( + self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions): + raise RuntimeError('failed to validate globals in new module:', hashToUse) - # ignore 'marked invalid' - if os.path.exists(os.path.join(targetDir, "marked_invalid")): - # just bail - don't try to read it now + def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None: + if moduleHash in self.moduleManifestsLoaded: + return - # for the moment, we don't try to clean up the cache, because - # we can't be sure that some process is not still reading the - # old files. 
- self.modulesMarkedInvalid.add(moduleHash) - return False + targetDir = os.path.join(self.cacheDir, moduleHash) with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: submodules = SerializationContext().deserialize(f.read(), ListOf(str)) for subHash in submodules: - if not self.loadNameManifestFromStoredModuleByHash(subHash): - self.markModuleHashInvalid(subHash) - return False + self.loadNameManifestFromStoredModuleByHash(subHash) with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f: self.nameToModuleHash.update( SerializationContext().deserialize(f.read(), Dict(str, str)) ) - self.modulesMarkedValid.add(moduleHash) - - return True + self.moduleManifestsLoaded.add(moduleHash) - def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules): + def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist): """Write out a disk representation of this module. This includes writing both the shared object, a manifest of the function names @@ -246,11 +274,17 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule # write the type manifest with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f: - f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions)) + f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions)) with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f: f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str))) + with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof + + with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.globalDependencies)) + try: os.rename(tempTargetDir, targetDir) except IOError: @@ -266,7 +300,7 @@ def 
function_pointer_by_name(self, linkName): if moduleHash is None: raise Exception("Can't find a module for " + linkName) - if moduleHash not in self.loadedModules: + if moduleHash not in self.loadedBinarySharedObjects: self.loadForSymbol(linkName) - return self.loadedModules[moduleHash].functionPointers[linkName] + return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName] diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py index 81ad2f12..4d35fbc3 100644 --- a/typed_python/compiler/compiler_cache_test.py +++ b/typed_python/compiler/compiler_cache_test.py @@ -119,6 +119,7 @@ def test_compiler_cache_understands_type_changes(): VERSION1 = {'x.py': xmodule, 'y.py': ymodule} VERSION2 = {'x.py': xmodule.replace("1: 2", "1: 3"), 'y.py': ymodule} VERSION3 = {'x.py': xmodule.replace("int, int", "int, float").replace('1: 2', '1: 2.5'), 'y.py': ymodule} + VERSION4 = {'x.py': xmodule.replace("1: 2", "1: 4"), 'y.py': ymodule} assert '1: 3' in VERSION2['x.py'] @@ -134,6 +135,10 @@ def test_compiler_cache_understands_type_changes(): assert evaluateExprInFreshProcess(VERSION3, 'y.g(1)', compilerCacheDir) == 2.5 assert len(os.listdir(compilerCacheDir)) == 2 + # use the previously compiled module + assert evaluateExprInFreshProcess(VERSION4, 'y.g(1)', compilerCacheDir) == 4 + assert len(os.listdir(compilerCacheDir)) == 2 + @pytest.mark.skipif('sys.platform=="darwin"') def test_compiler_cache_handles_exceptions_properly(): @@ -362,12 +367,9 @@ def test_compiler_cache_handles_changed_types(): assert evaluateExprInFreshProcess(VERSION2, 'x.f(1)', compilerCacheDir) == 1 assert len(os.listdir(compilerCacheDir)) == 2 - badCt = 0 - for subdir in os.listdir(compilerCacheDir): - if 'marked_invalid' in os.listdir(os.path.join(compilerCacheDir, subdir)): - badCt += 1 - - assert badCt == 1 + # if we then use g1 again, it should not have been marked invalid and so remains accessible. 
+ assert evaluateExprInFreshProcess(VERSION1, 'x.g1(1)', compilerCacheDir) == 1 + assert len(os.listdir(compilerCacheDir)) == 2 @pytest.mark.skipif('sys.platform=="darwin"') @@ -395,3 +397,34 @@ def test_ordering_is_stable_under_code_change(): ) assert names == names2 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_compiler_cache_avoids_deserialization_error(): + xmodule1 = "\n".join([ + "@Entrypoint", + "def f():", + " return None", + "import badModule", + "@Entrypoint", + "def g():", + " print(badModule)", + " return f()", + ]) + + xmodule2 = "\n".join([ + "@Entrypoint", + "def f():", + " return", + ]) + + VERSION1 = {'x.py': xmodule1, 'badModule.py': ''} + VERSION2 = {'x.py': xmodule2} + + with tempfile.TemporaryDirectory() as compilerCacheDir: + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 1 + evaluateExprInFreshProcess(VERSION2, 'x.f()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 diff --git a/typed_python/compiler/directed_graph.py b/typed_python/compiler/directed_graph.py index 5cf2f03e..7cc89b63 100644 --- a/typed_python/compiler/directed_graph.py +++ b/typed_python/compiler/directed_graph.py @@ -44,6 +44,10 @@ def hasEdge(self, source, dest): return False return dest in self.sourceToDest[source] + def clearOutgoing(self, node): + for child in list(self.outgoing(node)): + self.dropEdge(node, child) + def outgoing(self, node): return self.sourceToDest.get(node, set()) diff --git a/typed_python/compiler/global_variable_definition.py b/typed_python/compiler/global_variable_definition.py index 4f01c11f..4dbf34f8 100644 --- a/typed_python/compiler/global_variable_definition.py +++ b/typed_python/compiler/global_variable_definition.py @@ -79,3 +79,12 @@ def __init__(self, name, typ, metadata): self.name = name self.type = typ self.metadata = metadata + 
+ def __eq__(self, other): + if not isinstance(other, GlobalVariableDefinition): + return False + + return self.name == other.name and self.type == other.type and self.metadata == other.metadata + + def __str__(self): + return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={pad(str(self.metadata))})" diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index 9579df16..f33e5edb 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -22,7 +22,7 @@ from typed_python.compiler.binary_shared_object import BinarySharedObject import ctypes -from typed_python import _types +from typed_python import _types, SerializationContext llvm.initialize() llvm.initialize_native_target() @@ -84,18 +84,14 @@ def create_execution_engine(inlineThreshold): class Compiler: - def __init__(self, inlineThreshold): + def __init__(self, inlineThreshold, compilerCache): self.engine, self.module_pass_manager = create_execution_engine(inlineThreshold) - self.converter = native_ast_to_llvm.Converter() + self.converter = native_ast_to_llvm.Converter(compilerCache) self.functions_by_name = {} self.inlineThreshold = inlineThreshold self.verbose = False self.optimize = True - def markExternal(self, functionNameToType): - """Provide type signatures for a set of external functions.""" - self.converter.markExternal(functionNameToType) - def mark_converter_verbose(self): self.converter.verbose = True @@ -121,17 +117,20 @@ def buildSharedObject(self, functions): self.engine.finalize_object() + serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()} + return BinarySharedObject.fromModule( mod, - module.globalVariableDefinitions, + serializedGlobalVariableDefinitions, module.functionNameToType, + module.globalDependencies ) def function_pointer_by_name(self, name): return self.functions_by_name.get(name) def buildModule(self, 
functions): - """Compile a list of functions into a new module. + """Compile a list of functions into a new module. Only relevant if there is no compiler cache. Args: functions - a map from name to native_ast.Function @@ -187,4 +186,5 @@ def buildModule(self, functions): ) ) - return LoadedModule(native_function_pointers, module.globalVariableDefinitions) + serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()} + return LoadedModule(native_function_pointers, serializedGlobalVariableDefinitions) diff --git a/typed_python/compiler/llvm_compiler_test.py b/typed_python/compiler/llvm_compiler_test.py index e10f9453..d914bae4 100644 --- a/typed_python/compiler/llvm_compiler_test.py +++ b/typed_python/compiler/llvm_compiler_test.py @@ -20,6 +20,8 @@ from typed_python.compiler.module_definition import ModuleDefinition from typed_python.compiler.global_variable_definition import GlobalVariableMetadata +from typed_python.test_util import evaluateExprInFreshProcess + import pytest import ctypes @@ -115,7 +117,7 @@ def test_create_binary_shared_object(): {'__test_f_2': f} ) - assert len(bso.globalVariableDefinitions) == 1 + assert len(bso.serializedGlobalVariableDefinitions) == 1 with tempfile.TemporaryDirectory() as tf: loaded = bso.load(tf) @@ -131,3 +133,28 @@ def test_create_binary_shared_object(): pointers[0].set(5) assert loaded.functionPointers['__test_f_2']() == 5 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_loaded_modules_persist(): + """ + Make sure that loaded modules are persisted in the converter state. + + We have to maintain these references to avoid surprise segfaults - if this test fails, + it should be because the GlobalVariableDefinition memory management has been refactored. 
+ """ + + # compile a module + xmodule = "\n".join([ + "@Entrypoint", + "def f(x):", + " return x + 1", + "@Entrypoint", + "def g(x):", + " return f(x) * 100", + "g(1000)", + "def get_loaded_modules():", + " return len(Runtime.singleton().converter.loadedUncachedModules)" + ]) + VERSION1 = {'x.py': xmodule} + assert evaluateExprInFreshProcess(VERSION1, 'x.get_loaded_modules()') == 1 diff --git a/typed_python/compiler/loaded_module.py b/typed_python/compiler/loaded_module.py index c03ab321..ffb2112c 100644 --- a/typed_python/compiler/loaded_module.py +++ b/typed_python/compiler/loaded_module.py @@ -1,39 +1,48 @@ +from typing import Dict, List from typed_python.compiler.module_definition import ModuleDefinition -from typed_python import PointerTo, ListOf, Class +from typed_python import PointerTo, ListOf, Class, SerializationContext from typed_python import _types class LoadedModule: """Represents a bundle of compiled functions that are now loaded in memory. - Members: functionPointers - a map from name to NativeFunctionPointer giving the public interface of the module - globalVariableDefinitions - a map from name to GlobalVariableDefinition + serializedGlobalVariableDefinitions - a map from LLVM-assigned global name to serialized GlobalVariableDefinition giving the loadable strings """ GET_GLOBAL_VARIABLES_NAME = ModuleDefinition.GET_GLOBAL_VARIABLES_NAME - def __init__(self, functionPointers, globalVariableDefinitions): + def __init__(self, functionPointers, serializedGlobalVariableDefinitions): self.functionPointers = functionPointers + assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers + + self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions + self.orderedDefs = [ + self.serializedGlobalVariableDefinitions[name] for name in sorted(self.serializedGlobalVariableDefinitions) + ] + self.orderedDefNames = sorted(list(self.serializedGlobalVariableDefinitions.keys())) + self.pointers = ListOf(PointerTo(int))() + 
self.pointers.resize(len(self.orderedDefs)) - self.globalVariableDefinitions = globalVariableDefinitions + self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0)) + + self.installedGlobalVariableDefinitions = {} @staticmethod - def validateGlobalVariables(globalVariableDefinitions): + def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool: """Check that each global variable definition is sensible. - Sometimes we may successfully deserialize a global variable from a cached module, but then some dictionary member is not valid because it was removed or has the wrong type. In this case, we need to evict this module from the cache because it's no longer valid. Args: - globalVariableDefinitions - a dict from string to GlobalVariableMetadata + serializedGlobalVariableDefinitions: a dict from string to a serialized GlobalVariableMetadata """ - for gvd in globalVariableDefinitions.values(): - meta = gvd.metadata - + for gvd in serializedGlobalVariableDefinitions.values(): + meta = SerializationContext().deserialize(gvd).metadata if meta.matches.PointerToTypedPythonObjectAsMemberOfDict: if not isinstance(meta.sourceDict, dict): return False @@ -54,54 +63,47 @@ def validateGlobalVariables(globalVariableDefinitions): return True - def linkGlobalVariables(self): - """Walk over all global variables in the module and make sure they are populated. - + def linkGlobalVariables(self, variable_names: List[str] = None) -> None: + """Walk over all global variables in `variable_names` and make sure they are populated. Each module has a bunch of global variables that contain references to things like type objects, string objects, python module members, etc. - - The metadata about these is stored in 'self.globalVariableDefinitions' whose keys + The metadata about these is stored in 'self.serializedGlobalVariableDefinitions' whose keys are names and whose values are GlobalVariableMetadata instances. 
- Every module we compile exposes a member named ModuleDefinition.GET_GLOBAL_VARIABLES_NAME which takes a pointer to a list of pointers and fills it out with the global variables. - When the module is loaded, all the variables are initialized to zero. This function walks over them and populates them, effectively linking them into the current binary. """ - assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers - - orderedDefs = [ - self.globalVariableDefinitions[name] for name in sorted(self.globalVariableDefinitions) - ] - pointers = ListOf(PointerTo(int))() - pointers.resize(len(orderedDefs)) + if variable_names is None: + i_vals = range(len(self.orderedDefs)) + else: + i_vals = [self.orderedDefNames.index(x) for x in variable_names] - self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](pointers.pointerUnsafe(0)) + for i in i_vals: + assert self.pointers[i], f"Failed to get a pointer to {self.orderedDefs[i].name}" - for i in range(len(orderedDefs)): - assert pointers[i], f"Failed to get a pointer to {orderedDefs[i].name}" + meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata - meta = orderedDefs[i].metadata + self.installedGlobalVariableDefinitions[i] = meta if meta.matches.StringConstant: - pointers[i].cast(str).initialize(meta.value) + self.pointers[i].cast(str).initialize(meta.value) if meta.matches.IntegerConstant: - pointers[i].cast(int).initialize(meta.value) + self.pointers[i].cast(int).initialize(meta.value) elif meta.matches.BytesConstant: - pointers[i].cast(bytes).initialize(meta.value) + self.pointers[i].cast(bytes).initialize(meta.value) elif meta.matches.PointerToPyObject: - pointers[i].cast(object).initialize(meta.value) + self.pointers[i].cast(object).initialize(meta.value) elif meta.matches.PointerToTypedPythonObject: - pointers[i].cast(meta.type).initialize(meta.value) + self.pointers[i].cast(meta.type).initialize(meta.value) elif meta.matches.PointerToTypedPythonObjectAsMemberOfDict: - 
pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) + self.pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) elif meta.matches.ClassMethodDispatchSlot: slotIx = _types.allocateClassMethodDispatch( @@ -111,17 +113,17 @@ def linkGlobalVariables(self): meta.argTupleType, meta.kwargTupleType ) - pointers[i].cast(int).initialize(slotIx) + self.pointers[i].cast(int).initialize(slotIx) elif meta.matches.IdOfPyObject: - pointers[i].cast(int).initialize(id(meta.value)) + self.pointers[i].cast(int).initialize(id(meta.value)) elif meta.matches.ClassVtable: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types._vtablePointer(meta.value) ) elif meta.matches.RawTypePointer: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types.getTypePointer(meta.value) ) diff --git a/typed_python/compiler/module_definition.py b/typed_python/compiler/module_definition.py index 4e9b35fd..cadbb2ec 100644 --- a/typed_python/compiler/module_definition.py +++ b/typed_python/compiler/module_definition.py @@ -18,15 +18,19 @@ class ModuleDefinition: """A single module of compiled llvm code. - Members: - moduleText - a string containing the llvm IR for the module - functionList - a list of the names of exported functions - globalDefinitions - a dict from name to a GlobalDefinition + Attributes: + moduleText (str): a string containing the llvm IR for the module + functionNameToType (dict): a dict from exported function name to its native function type + globalVariableDefinitions (dict): a dict from name to a GlobalVariableDefinition + globalDependencies (dict): a dict from function link_name to a list of globals the + function depends on + hash (str): The module hash, generated from the llvm IR. 
""" GET_GLOBAL_VARIABLES_NAME = ".get_global_variables" - def __init__(self, moduleText, functionNameToType, globalVariableDefinitions): + def __init__(self, moduleText, functionNameToType, globalVariableDefinitions, globalDependencies): self.moduleText = moduleText self.functionNameToType = functionNameToType self.globalVariableDefinitions = globalVariableDefinitions + self.globalDependencies = globalDependencies self.hash = sha_hash(moduleText) diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index 4850ed95..bbd2027d 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typed_python.compiler.native_ast as native_ast -from typed_python.compiler.module_definition import ModuleDefinition -from typed_python.compiler.global_variable_definition import GlobalVariableDefinition import llvmlite.ir import os - +import typed_python.compiler.native_ast as native_ast +from typed_python.compiler.global_variable_definition import GlobalVariableDefinition +from typed_python.compiler.module_definition import ModuleDefinition +from typing import Dict llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer() llvm_i8 = llvmlite.ir.IntType(8) llvm_i32 = llvmlite.ir.IntType(32) @@ -501,19 +501,19 @@ def __init__(self, module, globalDefinitions, globalDefinitionLlvmValues, - function, converter, builder, arg_assignments, output_type, - external_function_references + external_function_references, + compilerCache, ): - self.function = function # dict from name to GlobalVariableDefinition self.globalDefinitions = globalDefinitions self.globalDefinitionLlvmValues = globalDefinitionLlvmValues - + # a list of the global LLVM names that the function depends on. 
+ self.global_names = [] self.module = module self.converter = converter self.builder = builder @@ -522,6 +522,7 @@ def __init__(self, self.external_function_references = external_function_references self.tags_initialized = {} self.stack_slots = {} + self.compilerCache = compilerCache def tags_as(self, new_tags): class scoper(): @@ -631,7 +632,16 @@ def generate_exception_and_store_value(self, llvm_pointer_val): ) return self.builder.bitcast(exception_ptr, llvm_i8ptr) - def namedCallTargetToLLVM(self, target): + def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVMValue: + """ + Generate llvm IR code for a given target. + + There are three options for code generation: + 1. The target is external, i.e. something like pyobj_len, np_add_traceback - system-level functions. We add to + external_function_references. + 2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision. + 3. We have a compiler cache, and the function is in it. We add to external_function_references. 
+ """ if target.external: if target.name not in self.external_function_references: func_type = llvmlite.ir.FunctionType( @@ -648,7 +658,23 @@ def namedCallTargetToLLVM(self, target): llvmlite.ir.Function(self.module, func_type, target.name) func = self.external_function_references[target.name] - elif target.name in self.converter._externallyDefinedFunctionTypes: + elif target.name in self.converter._function_definitions: + func = self.converter._functions_by_name[target.name] + if func.module is not self.module: + # first, see if we'd like to inline this module + if ( + self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY + ): + func = self.converter.repeatFunctionInModule(target.name, self.module) + else: + if target.name not in self.external_function_references: + self.external_function_references[target.name] = \ + llvmlite.ir.Function(self.module, func.function_type, func.name) + + func = self.external_function_references[target.name] + else: + # TODO (Will): decide whether to inline cached code + assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is defined in a shared object that we've loaded from a prior # invocation if target.name not in self.external_function_references: @@ -665,22 +691,6 @@ def namedCallTargetToLLVM(self, target): ) func = self.external_function_references[target.name] - else: - func = self.converter._functions_by_name[target.name] - - if func.module is not self.module: - # first, see if we'd like to inline this module - if ( - self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY - and self.converter.canBeInlined(target.name) - ): - func = self.converter.repeatFunctionInModule(target.name, self.module) - else: - if target.name not in self.external_function_references: - self.external_function_references[target.name] = \ - llvmlite.ir.Function(self.module, func.function_type, func.name) - - func = 
self.external_function_references[target.name] return TypedLLVMValue( func, @@ -801,6 +811,7 @@ def _convert(self, expr): return self.stack_slots[expr.name] if expr.matches.GlobalVariable: + self.global_names.append(expr.name) if expr.name in self.globalDefinitions: assert expr.metadata == self.globalDefinitions[expr.name].metadata, ( expr.metadata, self.globalDefinitions[expr.name].metadata @@ -1484,15 +1495,11 @@ def define(fname, output, inputs, vararg=False): class Converter: - def __init__(self): + def __init__(self, compilerCache=None): object.__init__(self) self._modules = {} - self._functions_by_name = {} - self._function_definitions = {} - - # a map from function name to function type for functions that - # are defined in external shared objects and linked in to this one. - self._externallyDefinedFunctionTypes = {} + self._functions_by_name: Dict[str, llvmlite.ir.Function] = {} + self._function_definitions: Dict[str, native_ast.Function] = {} # total number of instructions in each function, by name self._function_complexity = {} @@ -1502,17 +1509,12 @@ def __init__(self): self._printAllNativeCalls = os.getenv("TP_COMPILER_LOG_NATIVE_CALLS") self.verbose = False - def markExternal(self, functionNameToType): - """Provide type signatures for a set of external functions.""" - self._externallyDefinedFunctionTypes.update(functionNameToType) - - def canBeInlined(self, name): - return name not in self._externallyDefinedFunctionTypes + self.compilerCache = compilerCache def totalFunctionComplexity(self, name): """Return the total number of instructions contained in a function. - The function must already have been defined in a prior parss. We use this + The function must already have been defined in a prior pass. We use this information to decide which functions to repeat in new module definitions. 
""" if name in self._function_complexity: @@ -1546,9 +1548,7 @@ def repeatFunctionInModule(self, name, module): assert isinstance(funcType, llvmlite.ir.FunctionType) self._functions_by_name[name] = llvmlite.ir.Function(module, funcType, name) - self._inlineRequests.append(name) - return self._functions_by_name[name] def add_functions(self, names_to_definitions): @@ -1604,7 +1604,8 @@ def add_functions(self, names_to_definitions): globalDefinitions = {} globalDefinitionsLlvmValues = {} - + # we need a separate dictionary owing to the possibility of global var reuse across functions. + globalDependencies = {} while names_to_definitions: for name in sorted(names_to_definitions): definition = names_to_definitions.pop(name) @@ -1628,12 +1629,12 @@ def add_functions(self, names_to_definitions): module, globalDefinitions, globalDefinitionsLlvmValues, - func, self, builder, arg_assignments, definition.output_type, - external_function_references + external_function_references, + self.compilerCache, ) func_converter.setup() @@ -1642,6 +1643,8 @@ def add_functions(self, names_to_definitions): func_converter.finalize() + globalDependencies[func.name] = func_converter.global_names + if res is not None: assert res.llvm_value is None if definition.output_type != native_ast.Void: @@ -1675,7 +1678,8 @@ def add_functions(self, names_to_definitions): return ModuleDefinition( str(module), functionTypes, - globalDefinitions + globalDefinitions, + globalDependencies ) def defineGlobalMetadataAccessor(self, module, globalDefinitions, globalDefinitionsLlvmValues): diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index c9bb2748..237fdb1d 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -17,6 +17,7 @@ from typed_python.hash import Hash from types import ModuleType +from typing import Dict from typed_python import Class import 
typed_python.python_ast as python_ast import typed_python._types as _types @@ -57,6 +58,26 @@ def __init__(self): # (priority, node) pairs that need to recompute self._dirty_inflight_functions_with_order = SortedSet(key=lambda pair: pair[0]) + def reachableInSet(self, rootIdentity, activeSet): + """Produce the subset of 'activeSet' that are reachable from 'rootIdentity'""" + reachable = set() + + def walk(node): + if node in reachable or node not in activeSet: + return + + reachable.add(node) + + for child in self._dependencies.outgoing(node): + walk(child) + + walk(rootIdentity) + + return reachable + + def clearOutgoingEdgesFor(self, identity): + self._dependencies.clearOutgoing(identity) + def dropNode(self, node): self._dependencies.dropNode(node, False) if node in self._identity_levels: @@ -72,19 +93,21 @@ def getNextDirtyNode(self): return identity - def addRoot(self, identity): + def addRoot(self, identity, dirty=True): if identity not in self._identity_levels: self._identity_levels[identity] = 0 - self.markDirty(identity) + if dirty: + self.markDirty(identity) - def addEdge(self, caller, callee): + def addEdge(self, caller, callee, dirty=True): if caller not in self._identity_levels: raise Exception(f"unknown identity {caller} found in the graph") if callee not in self._identity_levels: self._identity_levels[callee] = self._identity_levels[caller] + 1 - self.markDirty(callee, isNew=True) + if dirty: + self.markDirty(callee, isNew=True) self._dependencies.addEdge(caller, callee) @@ -122,21 +145,21 @@ def __init__(self, llvmCompiler, compilerCache): self.llvmCompiler = llvmCompiler self.compilerCache = compilerCache + # all LoadedModule objects that we have created. We need to keep them alive so + # that any python metadata objects they've created stay alive as well. Ultimately, this + # may not be the place we put these objects (for instance, you could imagine a + # 'dummy' compiler cache or something). But for now, we need to keep them alive. 
+ self.loadedUncachedModules = [] + # if True, then insert additional code to check for undefined behavior. self.generateDebugChecks = False - # all link names for which we have a definition. - self._allDefinedNames = set() - - # all names we loaded from the cache - self._allCachedNames = set() - self._link_name_for_identity = {} self._identity_for_link_name = {} - self._definitions = {} - self._targets = {} + self._definitions: Dict[str, native_ast.Function] = {} + self._targets: Dict[str, TypedCallTarget] = {} self._inflight_definitions = {} - self._inflight_function_conversions = {} + self._inflight_function_conversions: Dict[str, FunctionConversionContext] = {} self._identifier_to_pyfunc = {} self._times_calculated = {} @@ -194,16 +217,19 @@ def buildAndLinkNewModule(self): if self.compilerCache is None: loadedModule = self.llvmCompiler.buildModule(targets) loadedModule.linkGlobalVariables() + self.loadedUncachedModules.append(loadedModule) return # get a set of function names that we depend on externallyUsed = set() + dependency_edgelist = [] for funcName in targets: ident = self._identity_for_link_name.get(funcName) if ident is not None: for dep in self._dependencies.getNamesDependedOn(ident): depLN = self._link_name_for_identity.get(dep) + dependency_edgelist.append([funcName, depLN]) if depLN not in targets: externallyUsed.add(depLN) @@ -211,8 +237,9 @@ def buildAndLinkNewModule(self): self.compilerCache.addModule( binary, - {name: self._targets[name] for name in targets if name in self._targets}, - externallyUsed + {name: self.getTarget(name) for name in targets if self.hasTarget(name)}, + externallyUsed, + dependency_edgelist, ) def extract_new_function_definitions(self): @@ -274,35 +301,31 @@ def defineLinkName(self, identity, linkName): self._link_name_for_identity[identity] = linkName self._identity_for_link_name[linkName] = identity - if linkName in self._allDefinedNames: - return False - - self._allDefinedNames.add(linkName) + def hasTarget(self, 
linkName): + return self.getTarget(linkName) is not None - self._loadFromCompilerCache(linkName) + def deleteTarget(self, linkName): + self._targets.pop(linkName) - return True + def setTarget(self, linkName, target): + assert(isinstance(target, TypedCallTarget)) + self._targets[linkName] = target - def _loadFromCompilerCache(self, linkName): - if self.compilerCache: - if self.compilerCache.hasSymbol(linkName): - callTargetsAndTypes = self.compilerCache.loadForSymbol(linkName) - - if callTargetsAndTypes is not None: - newTypedCallTargets, newNativeFunctionTypes = callTargetsAndTypes + def getTarget(self, linkName): + if linkName in self._targets: + return self._targets[linkName] - self._targets.update(newTypedCallTargets) - self.llvmCompiler.markExternal(newNativeFunctionTypes) + if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName): + return self.compilerCache.getTarget(linkName) - self._allDefinedNames.update(newNativeFunctionTypes) - self._allCachedNames.update(newNativeFunctionTypes) + return None def defineNonPythonFunction(self, name, identityTuple, context): """Define a non-python generating function (if we haven't defined it before already) name - the name to actually give the function. 
identityTuple - a unique (sha)hashable tuple - context - a FunctionConvertsionContext lookalike + context - a FunctionConversionContext lookalike returns a TypedCallTarget, or None if it's not known yet """ @@ -311,32 +334,36 @@ def defineNonPythonFunction(self, name, identityTuple, context): self.defineLinkName(identity, linkName) + target = self.getTarget(linkName) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if linkName in self._targets: - return self._targets.get(linkName) + if target is not None: + return target self._inflight_function_conversions[identity] = context if context.knownOutputType() is not None or context.alwaysRaises(): - self._targets[linkName] = self.getTypedCallTarget( - name, - context.getInputTypes(), - context.knownOutputType(), - alwaysRaises=context.alwaysRaises(), - functionMetadata=context.functionMetadata + self.setTarget( + linkName, + self.getTypedCallTarget( + name, + context.getInputTypes(), + context.knownOutputType(), + alwaysRaises=context.alwaysRaises(), + functionMetadata=context.functionMetadata, + ) ) if self._currentlyConverting is None: # force the function to resolve immediately self._resolveAllInflightFunctions() - self._installInflightFunctions(name) - self._inflight_function_conversions.clear() + self._installInflightFunctions(identity) - return self._targets.get(linkName) + return self.getTarget(linkName) def defineNativeFunction(self, name, identity, input_types, output_type, generatingFunction): """Define a native function if we haven't defined it before already. 
@@ -459,13 +486,19 @@ def generateCallConverter(self, callTarget: TypedCallTarget): identifier = "call_converter_" + callTarget.name linkName = callTarget.name + ".dispatch" - if linkName in self._allDefinedNames: + # we already made a definition for this in this process so don't do it again + if linkName in self._definitions: return linkName - self._loadFromCompilerCache(linkName) - if linkName in self._allDefinedNames: + # we already defined it in another process so don't do it again + if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName): return linkName + # N.B. there aren't targets for call converters. We make the definition directly. + + # if self.getTarget(linkName): + # return linkName + args = [] for i in range(len(callTarget.input_types)): if not callTarget.input_types[i].is_empty: @@ -503,7 +536,6 @@ def generateCallConverter(self, callTarget: TypedCallTarget): self._link_name_for_identity[identifier] = linkName self._identity_for_link_name[linkName] = identifier - self._allDefinedNames.add(linkName) self._definitions[linkName] = definition self._new_native_functions.add(linkName) @@ -516,10 +548,6 @@ def _resolveAllInflightFunctions(self): if not identity: return - linkName = self._link_name_for_identity[identity] - if linkName in self._allCachedNames: - continue - functionConverter = self._inflight_function_conversions[identity] hasDefinitionBeforeConversion = identity in self._inflight_definitions @@ -529,6 +557,10 @@ self._times_calculated[identity] = self._times_calculated.get(identity, 0) + 1 + # this calls back into convert with dependencies + # they get registered as dirty + self._dependencies.clearOutgoingEdgesFor(identity) + nativeFunction, actual_output_type = functionConverter.convertToNativeFunction() if nativeFunction is not None: @@ -537,9 +569,8 @@ for i in self._inflight_function_conversions: if i in self._link_name_for_identity: 
name = self._link_name_for_identity[i] - if name in self._targets: - self._targets.pop(name) - self._allDefinedNames.discard(name) + if self.hasTarget(name): + self.deleteTarget(name) ln = self._link_name_for_identity.pop(i) self._identity_for_link_name.pop(ln) @@ -567,12 +598,15 @@ def _resolveAllInflightFunctions(self): name = self._link_name_for_identity[identity] - self._targets[name] = self.getTypedCallTarget( + self.setTarget( name, - functionConverter._input_types, - actual_output_type, - alwaysRaises=functionConverter.alwaysRaises(), - functionMetadata=functionConverter.functionMetadata + self.getTypedCallTarget( + name, + functionConverter._input_types, + actual_output_type, + alwaysRaises=functionConverter.alwaysRaises(), + functionMetadata=functionConverter.functionMetadata, + ), ) if dirtyUpstream: @@ -860,13 +894,15 @@ def convert( if assertIsRoot: assert isRoot + target = self.getTarget(name) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if name in self._targets: - return self._targets[name] + if target is not None: + return target if identity not in self._inflight_function_conversions: functionConverter = self.createConversionContext( @@ -880,14 +916,13 @@ def convert( output_type, conversionType ) - self._inflight_function_conversions[identity] = functionConverter if isRoot: try: self._resolveAllInflightFunctions() - self._installInflightFunctions(name) - return self._targets[name] + self._installInflightFunctions(identity) + return self.getTarget(name) finally: self._inflight_function_conversions.clear() @@ -897,12 +932,12 @@ def convert( # target with an output type and we can return that. 
Otherwise we have to # return None, which will cause callers to replace this with a throw # until we have had a chance to do a full pass of conversion. - if name in self._targets: - return self._targets[name] - else: - return None + if self.getTarget(name) is not None: + raise RuntimeError(f"Unexpected conversion error for {name}") + return None - def _installInflightFunctions(self, name): + def _installInflightFunctions(self, rootIdentity): + """Add all function definitions corresponding to keys in inflight_function_conversions to the relevant dictionaries.""" if VALIDATE_FUNCTION_DEFINITIONS_STABLE: # this should always be true, but its expensive so we have it off by default for identifier, functionConverter in self._inflight_function_conversions.items(): @@ -915,11 +950,25 @@ def _installInflightFunctions(self, name): finally: self._currentlyConverting = None + # restrict to the set of inflight functions that are reachable from rootName + # we produce copies of functions that we don't actually need to compile during + # early phases of type inference + reachable = self._dependencies.reachableInSet( + rootIdentity, + set(self._inflight_function_conversions) + ) + for identifier, functionConverter in self._inflight_function_conversions.items(): + if identifier not in reachable: + continue outboundTargets = [] for outboundFuncId in self._dependencies.getNamesDependedOn(identifier): name = self._link_name_for_identity[outboundFuncId] - outboundTargets.append(self._targets[name]) + target = self.getTarget(name) + if target is not None: + outboundTargets.append(target) + else: + raise RuntimeError(f'dependency not found for {name}.') nativeFunction, actual_output_type = self._inflight_definitions.get(identifier) @@ -969,3 +1018,6 @@ def _installInflightFunctions(self, name): self._definitions[name] = nativeFunction self._new_native_functions.add(name) + + self._inflight_definitions.clear() + self._inflight_function_conversions.clear() diff --git 
a/typed_python/compiler/runtime.py b/typed_python/compiler/runtime.py index 8621147a..b27cc19e 100644 --- a/typed_python/compiler/runtime.py +++ b/typed_python/compiler/runtime.py @@ -207,7 +207,7 @@ def __init__(self): ) else: self.compilerCache = None - self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100) + self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100, compilerCache=self.compilerCache) self.converter = python_to_native_converter.PythonToNativeConverter( self.llvm_compiler, self.compilerCache diff --git a/typed_python/compiler/tests/numpy_interaction_test.py b/typed_python/compiler/tests/numpy_interaction_test.py index f15bfea9..db774c6d 100644 --- a/typed_python/compiler/tests/numpy_interaction_test.py +++ b/typed_python/compiler/tests/numpy_interaction_test.py @@ -1,4 +1,4 @@ -from typed_python import ListOf, Entrypoint +from typed_python import ListOf, Entrypoint, SerializationContext import numpy import numpy.linalg @@ -44,3 +44,12 @@ def test_listof_from_sliced_numpy_array(): y = x[::2] assert ListOf(int)(y) == [0, 2] + + +def test_can_serialize_numpy_ufunc(): + assert numpy.sin == SerializationContext().deserialize(SerializationContext().serialize(numpy.sin)) + + +def test_can_serialize_numpy_array(): + x = numpy.ones(10) + assert (x == SerializationContext().deserialize(SerializationContext().serialize(x))).all() diff --git a/typed_python/compiler/tests/type_of_instances_compilation_test.py b/typed_python/compiler/tests/type_of_instances_compilation_test.py index 337bdc3f..c3fdf459 100644 --- a/typed_python/compiler/tests/type_of_instances_compilation_test.py +++ b/typed_python/compiler/tests/type_of_instances_compilation_test.py @@ -17,13 +17,13 @@ def typeOfArg(x: C): def test_type_of_alternative_is_specific(): for members in [{}, {'a': int}]: - A = Alternative("A", A=members) + Alt = Alternative("Alt", A=members) @Entrypoint - def typeOfArg(x: A): + def typeOfArg(x: Alt): return type(x) - assert typeOfArg(A.A()) is A.A + 
assert typeOfArg(Alt.A()) is Alt.A def test_type_of_concrete_alternative_is_specific(): diff --git a/typed_python/lib/datetime/date_time.py b/typed_python/lib/datetime/date_time.py index 634a624c..5b189b6f 100644 --- a/typed_python/lib/datetime/date_time.py +++ b/typed_python/lib/datetime/date_time.py @@ -269,7 +269,7 @@ def firstOfMonth(self): return Date(self.year, self.month, 1) def quarterOfYear(self): - return (self.date.month - 1) // 3 + 1 + return (self.month - 1) // 3 + 1 def next(self, step: int = 1): """Returns the date `step` days ahead of `self`. diff --git a/typed_python/lib/datetime/date_time_test.py b/typed_python/lib/datetime/date_time_test.py index 1e7e3f1b..1f14deb9 100644 --- a/typed_python/lib/datetime/date_time_test.py +++ b/typed_python/lib/datetime/date_time_test.py @@ -1,6 +1,7 @@ import pytest import pytz import datetime +from typed_python import NamedTuple from typed_python.lib.datetime.date_time import ( Date, DateTime, @@ -15,8 +16,8 @@ OneFoldOnlyError, PytzTimezone, ) -from typed_python import Timestamp, NamedTuple from typed_python.lib.datetime.date_parser_test import get_datetimes_in_range +from typed_python.lib.timestamp import Timestamp def test_last_weekday_of_month(): diff --git a/typed_python/types_serialization_test.py b/typed_python/types_serialization_test.py index c5f33c2a..e34da54b 100644 --- a/typed_python/types_serialization_test.py +++ b/typed_python/types_serialization_test.py @@ -15,6 +15,8 @@ import sys import os import importlib +from functools import lru_cache + from abc import ABC, abstractmethod, ABCMeta from typed_python.test_util import callFunctionInFreshProcess import typed_python.compiler.python_ast_util as python_ast_util @@ -57,6 +59,13 @@ module_level_testfun = dummy_test_module.testfunction +class GlobalClassWithLruCache: + @staticmethod + @lru_cache(maxsize=None) + def f(x): + return x + + def moduleLevelFunctionUsedByExactlyOneSerializationTest(): return "please don't touch me" @@ -3061,3 +3070,34 @@ 
def f(self): print(x) # TODO: make this True # assert x[0].f.__closure__[0].cell_contents is x + + def test_serialize_pyobj_with_custom_reduce(self): + class CustomReduceObject: + def __reduce__(self): + return 'CustomReduceObject' + + assert CustomReduceObject == SerializationContext().deserialize(SerializationContext().serialize(CustomReduceObject)) + + def test_serialize_pyobj_in_MRTG_with_custom_reduce(self): + def getX(): + class InnerCustomReduceObject: + def __reduce__(self): + return 'InnerCustomReduceObject' + + def f(self): + return x + + x = (InnerCustomReduceObject, InnerCustomReduceObject) + + return x + + x = callFunctionInFreshProcess(getX, (), showStdout=True) + + assert x == SerializationContext().deserialize(SerializationContext().serialize(x)) + + def test_serialize_class_static_lru_cache(self): + s = SerializationContext() + + assert ( + s.deserialize(s.serialize(GlobalClassWithLruCache.f)) is GlobalClassWithLruCache.f + )