From 7717948b37406eded98fe09f1bef9ef567d72536 Mon Sep 17 00:00:00 2001 From: Augustus Lonergan Date: Tue, 10 Jan 2023 22:23:05 +0000 Subject: [PATCH 1/7] Limit compiler cache interface to getTarget and addModule Current implementation of the interface, used via _loadFromCompilerCache, runs loadForSymbol for a given link name and then returns two dicts representing everything the cache touched during this load process. Maintaining this interface makes the partial-load refactor difficult, and muddies the converter layers/cache relationship. Here we alter the API so that we can ask the cache for a TypedCallTarget, and add modules, and that's it. This means getting rid of _loadFromCompilerCache, and associated registers for tracking what's being converted. Also means passing the cache down to the native_ast_to_llvm layer. --- typed_python/compiler/compiler_cache.py | 28 ++-- typed_python/compiler/llvm_compiler.py | 8 +- typed_python/compiler/native_ast_to_llvm.py | 60 ++++---- .../compiler/python_to_native_converter.py | 130 +++++++++--------- typed_python/compiler/runtime.py | 2 +- 5 files changed, 108 insertions(+), 120 deletions(-) diff --git a/typed_python/compiler/compiler_cache.py b/typed_python/compiler/compiler_cache.py index a093fc706..0e9f2063a 100644 --- a/typed_python/compiler/compiler_cache.py +++ b/typed_python/compiler/compiler_cache.py @@ -62,9 +62,18 @@ def __init__(self, cacheDir): if len(moduleHash) == 40: self.loadNameManifestFromStoredModuleByHash(moduleHash) + self.targetsLoaded = {} + def hasSymbol(self, linkName): return linkName in self.nameToModuleHash + def getTarget(self, linkName): + assert self.hasSymbol(linkName) + + self.loadForSymbol(linkName) + + return self.targetsLoaded[linkName] + def markModuleHashInvalid(self, hashstr): with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"): pass @@ -72,15 +81,9 @@ def markModuleHashInvalid(self, hashstr): def loadForSymbol(self, linkName): moduleHash = 
self.nameToModuleHash[linkName] - nameToTypedCallTarget = {} - nameToNativeFunctionType = {} - - if not self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): - return None - - return nameToTypedCallTarget, nameToNativeFunctionType + self.loadModuleByHash(moduleHash) - def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): + def loadModuleByHash(self, moduleHash): """Load a module by name. As we load, place all the newly imported typed call targets into @@ -114,11 +117,7 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti # load the submodules first for submodule in submodules: - if not self.loadModuleByHash( - submodule, - nameToTypedCallTarget, - nameToNativeFunctionType - ): + if not self.loadModuleByHash(submodule): return False modulePath = os.path.join(targetDir, "module.so") @@ -131,8 +130,7 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti self.loadedModules[moduleHash] = loaded - nameToTypedCallTarget.update(callTargets) - nameToNativeFunctionType.update(functionNameToNativeType) + self.targetsLoaded.update(callTargets) return True diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index 9579df168..4e019c032 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -84,18 +84,14 @@ def create_execution_engine(inlineThreshold): class Compiler: - def __init__(self, inlineThreshold): + def __init__(self, inlineThreshold, compilerCache): self.engine, self.module_pass_manager = create_execution_engine(inlineThreshold) - self.converter = native_ast_to_llvm.Converter() + self.converter = native_ast_to_llvm.Converter(compilerCache) self.functions_by_name = {} self.inlineThreshold = inlineThreshold self.verbose = False self.optimize = True - def markExternal(self, functionNameToType): - """Provide type signatures for a set of external functions.""" - 
self.converter.markExternal(functionNameToType) - def mark_converter_verbose(self): self.converter.verbose = True diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index 4850ed959..769b58f62 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -501,14 +501,13 @@ def __init__(self, module, globalDefinitions, globalDefinitionLlvmValues, - function, converter, builder, arg_assignments, output_type, - external_function_references + external_function_references, + compilerCache, ): - self.function = function # dict from name to GlobalVariableDefinition self.globalDefinitions = globalDefinitions @@ -522,6 +521,7 @@ def __init__(self, self.external_function_references = external_function_references self.tags_initialized = {} self.stack_slots = {} + self.compilerCache = compilerCache def tags_as(self, new_tags): class scoper(): @@ -648,7 +648,24 @@ def namedCallTargetToLLVM(self, target): llvmlite.ir.Function(self.module, func_type, target.name) func = self.external_function_references[target.name] - elif target.name in self.converter._externallyDefinedFunctionTypes: + elif target.name in self.converter._function_definitions: + func = self.converter._functions_by_name[target.name] + + if func.module is not self.module: + # first, see if we'd like to inline this module + if ( + self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY + ): + func = self.converter.repeatFunctionInModule(target.name, self.module) + else: + if target.name not in self.external_function_references: + self.external_function_references[target.name] = \ + llvmlite.ir.Function(self.module, func.function_type, func.name) + + func = self.external_function_references[target.name] + else: + # TODO: decide whether to inline based on something in the compiler cache + assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is 
defined in a shared object that we've loaded from a prior # invocation if target.name not in self.external_function_references: @@ -665,22 +682,6 @@ def namedCallTargetToLLVM(self, target): ) func = self.external_function_references[target.name] - else: - func = self.converter._functions_by_name[target.name] - - if func.module is not self.module: - # first, see if we'd like to inline this module - if ( - self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY - and self.converter.canBeInlined(target.name) - ): - func = self.converter.repeatFunctionInModule(target.name, self.module) - else: - if target.name not in self.external_function_references: - self.external_function_references[target.name] = \ - llvmlite.ir.Function(self.module, func.function_type, func.name) - - func = self.external_function_references[target.name] return TypedLLVMValue( func, @@ -1484,16 +1485,12 @@ def define(fname, output, inputs, vararg=False): class Converter: - def __init__(self): + def __init__(self, compilerCache): object.__init__(self) self._modules = {} self._functions_by_name = {} self._function_definitions = {} - # a map from function name to function type for functions that - # are defined in external shared objects and linked in to this one. - self._externallyDefinedFunctionTypes = {} - # total number of instructions in each function, by name self._function_complexity = {} @@ -1502,12 +1499,7 @@ def __init__(self): self._printAllNativeCalls = os.getenv("TP_COMPILER_LOG_NATIVE_CALLS") self.verbose = False - def markExternal(self, functionNameToType): - """Provide type signatures for a set of external functions.""" - self._externallyDefinedFunctionTypes.update(functionNameToType) - - def canBeInlined(self, name): - return name not in self._externallyDefinedFunctionTypes + self.compilerCache = compilerCache def totalFunctionComplexity(self, name): """Return the total number of instructions contained in a function. 
@@ -1621,19 +1613,19 @@ def add_functions(self, names_to_definitions): TypedLLVMValue(func.args[i], definition.args[i][1]) block = func.append_basic_block('entry') - builder = llvmlite.ir.IRBuilder(block) + builder = llvmlite.ir.IRBuilder(block) # this shares state with func try: func_converter = FunctionConverter( module, globalDefinitions, globalDefinitionsLlvmValues, - func, self, builder, arg_assignments, definition.output_type, - external_function_references + external_function_references, + self.compilerCache, ) func_converter.setup() diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index c9bb2748a..6c23a92a8 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -72,19 +72,21 @@ def getNextDirtyNode(self): return identity - def addRoot(self, identity): + def addRoot(self, identity, dirty=True): if identity not in self._identity_levels: self._identity_levels[identity] = 0 - self.markDirty(identity) + if dirty: + self.markDirty(identity) - def addEdge(self, caller, callee): + def addEdge(self, caller, callee, dirty=True): if caller not in self._identity_levels: raise Exception(f"unknown identity {caller} found in the graph") if callee not in self._identity_levels: self._identity_levels[callee] = self._identity_levels[caller] + 1 - self.markDirty(callee, isNew=True) + if dirty: + self.markDirty(callee, isNew=True) self._dependencies.addEdge(caller, callee) @@ -125,12 +127,6 @@ def __init__(self, llvmCompiler, compilerCache): # if True, then insert additional code to check for undefined behavior. self.generateDebugChecks = False - # all link names for which we have a definition. 
- self._allDefinedNames = set() - - # all names we loaded from the cache - self._allCachedNames = set() - self._link_name_for_identity = {} self._identity_for_link_name = {} self._definitions = {} @@ -211,7 +207,7 @@ def buildAndLinkNewModule(self): self.compilerCache.addModule( binary, - {name: self._targets[name] for name in targets if name in self._targets}, + {name: self.getTarget(name) for name in targets if self.hasTarget(name)}, externallyUsed ) @@ -274,28 +270,23 @@ def defineLinkName(self, identity, linkName): self._link_name_for_identity[identity] = linkName self._identity_for_link_name[linkName] = identity - if linkName in self._allDefinedNames: - return False - - self._allDefinedNames.add(linkName) - - self._loadFromCompilerCache(linkName) + def hasTarget(self, linkName): + return self.getTarget(linkName) is not None - return True + def deleteTarget(self, linkName): + self._targets.pop(linkName) - def _loadFromCompilerCache(self, linkName): - if self.compilerCache: - if self.compilerCache.hasSymbol(linkName): - callTargetsAndTypes = self.compilerCache.loadForSymbol(linkName) + def setTarget(self, linkName, target): + self._targets[linkName] = target - if callTargetsAndTypes is not None: - newTypedCallTargets, newNativeFunctionTypes = callTargetsAndTypes + def getTarget(self, linkName): + if linkName in self._targets: + return self._targets[linkName] - self._targets.update(newTypedCallTargets) - self.llvmCompiler.markExternal(newNativeFunctionTypes) + if self.compilerCache.hasSymbol(linkName): + return self.compilerCache.getTarget(linkName) - self._allDefinedNames.update(newNativeFunctionTypes) - self._allCachedNames.update(newNativeFunctionTypes) + return None def defineNonPythonFunction(self, name, identityTuple, context): """Define a non-python generating function (if we haven't defined it before already) @@ -311,23 +302,28 @@ def defineNonPythonFunction(self, name, identityTuple, context): self.defineLinkName(identity, linkName) + target = 
self.getTarget(linkName) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if linkName in self._targets: - return self._targets.get(linkName) + if target is not None: + return target self._inflight_function_conversions[identity] = context if context.knownOutputType() is not None or context.alwaysRaises(): - self._targets[linkName] = self.getTypedCallTarget( - name, - context.getInputTypes(), - context.knownOutputType(), - alwaysRaises=context.alwaysRaises(), - functionMetadata=context.functionMetadata + self.setTarget( + linkName, + self.getTypedCallTarget( + name, + context.getInputTypes(), + context.knownOutputType(), + alwaysRaises=context.alwaysRaises(), + functionMetadata=context.functionMetadata, + ) ) if self._currentlyConverting is None: @@ -336,7 +332,7 @@ def defineNonPythonFunction(self, name, identityTuple, context): self._installInflightFunctions(name) self._inflight_function_conversions.clear() - return self._targets.get(linkName) + return self.getTarget(linkName) def defineNativeFunction(self, name, identity, input_types, output_type, generatingFunction): """Define a native function if we haven't defined it before already. @@ -459,13 +455,19 @@ def generateCallConverter(self, callTarget: TypedCallTarget): identifier = "call_converter_" + callTarget.name linkName = callTarget.name + ".dispatch" - if linkName in self._allDefinedNames: + # # we already made a definition for this in this process so don't do it again + if linkName in self._definitions: return linkName - self._loadFromCompilerCache(linkName) - if linkName in self._allDefinedNames: + # # we already defined it in another process so don't do it again + if self.compilerCache.hasSymbol(linkName): return linkName + # N.B. 
there aren't targets for call converters. We make the definition directly. + + # if self.getTarget(linkName): + # return linkName + args = [] for i in range(len(callTarget.input_types)): if not callTarget.input_types[i].is_empty: @@ -503,7 +505,6 @@ def generateCallConverter(self, callTarget: TypedCallTarget): self._link_name_for_identity[identifier] = linkName self._identity_for_link_name[linkName] = identifier - self._allDefinedNames.add(linkName) self._definitions[linkName] = definition self._new_native_functions.add(linkName) @@ -516,10 +517,6 @@ def _resolveAllInflightFunctions(self): if not identity: return - linkName = self._link_name_for_identity[identity] - if linkName in self._allCachedNames: - continue - functionConverter = self._inflight_function_conversions[identity] hasDefinitionBeforeConversion = identity in self._inflight_definitions @@ -529,6 +526,8 @@ def _resolveAllInflightFunctions(self): self._times_calculated[identity] = self._times_calculated.get(identity, 0) + 1 + # this calls back into convert with dependencies + # they get registered as dirty nativeFunction, actual_output_type = functionConverter.convertToNativeFunction() if nativeFunction is not None: @@ -537,9 +536,8 @@ def _resolveAllInflightFunctions(self): for i in self._inflight_function_conversions: if i in self._link_name_for_identity: name = self._link_name_for_identity[i] - if name in self._targets: - self._targets.pop(name) - self._allDefinedNames.discard(name) + if self.hasTarget(name): + self.deleteTarget(name) ln = self._link_name_for_identity.pop(i) self._identity_for_link_name.pop(ln) @@ -567,12 +565,15 @@ def _resolveAllInflightFunctions(self): name = self._link_name_for_identity[identity] - self._targets[name] = self.getTypedCallTarget( + self.setTarget( name, - functionConverter._input_types, - actual_output_type, - alwaysRaises=functionConverter.alwaysRaises(), - functionMetadata=functionConverter.functionMetadata + self.getTypedCallTarget( + name, + 
functionConverter._input_types, + actual_output_type, + alwaysRaises=functionConverter.alwaysRaises(), + functionMetadata=functionConverter.functionMetadata, + ), ) if dirtyUpstream: @@ -860,13 +861,15 @@ def convert( if assertIsRoot: assert isRoot + target = self.getTarget(name) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if name in self._targets: - return self._targets[name] + if target is not None: + return target if identity not in self._inflight_function_conversions: functionConverter = self.createConversionContext( @@ -887,7 +890,7 @@ def convert( try: self._resolveAllInflightFunctions() self._installInflightFunctions(name) - return self._targets[name] + return self.getTarget(name) finally: self._inflight_function_conversions.clear() @@ -897,10 +900,9 @@ def convert( # target with an output type and we can return that. Otherwise we have to # return None, which will cause callers to replace this with a throw # until we have had a chance to do a full pass of conversion. 
- if name in self._targets: - return self._targets[name] - else: - return None + if self.getTarget(name) is not None: + raise Exception("This code looks unreachable to me...") + return None def _installInflightFunctions(self, name): if VALIDATE_FUNCTION_DEFINITIONS_STABLE: diff --git a/typed_python/compiler/runtime.py b/typed_python/compiler/runtime.py index 8621147a0..b27cc19ec 100644 --- a/typed_python/compiler/runtime.py +++ b/typed_python/compiler/runtime.py @@ -207,7 +207,7 @@ def __init__(self): ) else: self.compilerCache = None - self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100) + self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100, compilerCache=self.compilerCache) self.converter = python_to_native_converter.PythonToNativeConverter( self.llvm_compiler, self.compilerCache From 27827043b11af5c5f3d49d3c308800bef9b9759c Mon Sep 17 00:00:00 2001 From: William Grant Date: Wed, 11 Jan 2023 19:06:40 -0500 Subject: [PATCH 2/7] Allow for partial module loads in compiler cache. Previously we would always attempt to link and validate all global variables when loading a module from the cache. This caused linking errors, or validation errors, or deserialization errors, and meant we needed the mark_invalid mechanism for handling modules with outdated global variables. Here we add double-serialised global variables, and only deserialize,link&validate the subset required for the function required (and its dependencies). This requires the cache to store a function and global_var dependency graph. Also add utility methods for GlobalVariableDefinition. 
--- typed_python/compiler/binary_shared_object.py | 25 +-- typed_python/compiler/compiler_cache.py | 176 +++++++++++------- typed_python/compiler/compiler_cache_test.py | 45 ++++- .../compiler/global_variable_definition.py | 10 + typed_python/compiler/llvm_compiler.py | 12 +- typed_python/compiler/llvm_compiler_test.py | 2 +- typed_python/compiler/loaded_module.py | 78 ++++---- typed_python/compiler/module_definition.py | 14 +- typed_python/compiler/native_ast_to_llvm.py | 46 +++-- .../compiler/python_to_native_converter.py | 29 +-- 10 files changed, 270 insertions(+), 167 deletions(-) diff --git a/typed_python/compiler/binary_shared_object.py b/typed_python/compiler/binary_shared_object.py index 900892158..5c2e57653 100644 --- a/typed_python/compiler/binary_shared_object.py +++ b/typed_python/compiler/binary_shared_object.py @@ -26,8 +26,8 @@ class LoadedBinarySharedObject(LoadedModule): - def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariableDefinitions): - super().__init__(functionPointers, globalVariableDefinitions) + def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions): + super().__init__(functionPointers, serializedGlobalVariableDefinitions) self.binarySharedObject = binarySharedObject self.diskPath = diskPath @@ -36,15 +36,17 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariabl class BinarySharedObject: """Models a shared object library (.so) loadable on linux systems.""" - def __init__(self, binaryForm, functionTypes, globalVariableDefinitions): + def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies): """ Args: - binaryForm - a bytes object containing the actual compiled code for the module - globalVariableDefinitions - a map from name to GlobalVariableDefinition + binaryForm: a bytes object containing the actual compiled code for the module + serializedGlobalVariableDefinitions: a map from name to 
GlobalVariableDefinition + globalDependencies: a dict from function linkname to the list of global variables it depends on """ self.binaryForm = binaryForm self.functionTypes = functionTypes - self.globalVariableDefinitions = globalVariableDefinitions + self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions + self.globalDependencies = globalDependencies self.hash = sha_hash(binaryForm) @property @@ -52,14 +54,14 @@ def definedSymbols(self): return self.functionTypes.keys() @staticmethod - def fromDisk(path, globalVariableDefinitions, functionNameToType): + def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): with open(path, "rb") as f: binaryForm = f.read() - return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions) + return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) @staticmethod - def fromModule(module, globalVariableDefinitions, functionNameToType): + def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): target_triple = llvm.get_process_triple() target = llvm.Target.from_triple(target_triple) target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default') @@ -80,7 +82,7 @@ def fromModule(module, globalVariableDefinitions, functionNameToType): ) with open(os.path.join(tf, "module.so"), "rb") as so_file: - return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions) + return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) def load(self, storageDir): """Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer""" @@ -127,8 +129,7 @@ def loadFromPath(self, modulePath): self, modulePath, functionPointers, - self.globalVariableDefinitions + self.serializedGlobalVariableDefinitions ) - 
loadedModule.linkGlobalVariables() return loadedModule diff --git a/typed_python/compiler/compiler_cache.py b/typed_python/compiler/compiler_cache.py index 0e9f2063a..26f7e2a29 100644 --- a/typed_python/compiler/compiler_cache.py +++ b/typed_python/compiler/compiler_cache.py @@ -15,9 +15,12 @@ import os import uuid import shutil -from typed_python.compiler.loaded_module import LoadedModule -from typed_python.compiler.binary_shared_object import BinarySharedObject +from typing import Optional, List + +from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject +from typed_python.compiler.directed_graph import DirectedGraph +from typed_python.compiler.typed_call_target import TypedCallTarget from typed_python.SerializationContext import SerializationContext from typed_python import Dict, ListOf @@ -52,146 +55,173 @@ def __init__(self, cacheDir): ensureDirExists(cacheDir) - self.loadedModules = Dict(str, LoadedModule)() + self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)() self.nameToModuleHash = Dict(str, str)() - self.modulesMarkedValid = set() - self.modulesMarkedInvalid = set() + self.moduleManifestsLoaded = set() for moduleHash in os.listdir(self.cacheDir): if len(moduleHash) == 40: self.loadNameManifestFromStoredModuleByHash(moduleHash) - self.targetsLoaded = {} + # the set of functions with an associated module in loadedBinarySharedObjects + self.targetsLoaded: Dict[str, TypedCallTarget] = {} - def hasSymbol(self, linkName): - return linkName in self.nameToModuleHash + # the set of functions with linked and validated globals (i.e. ready to be run). 
+ self.targetsValidated = set() - def getTarget(self, linkName): - assert self.hasSymbol(linkName) + self.function_dependency_graph = DirectedGraph() + # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions) + self.global_dependencies = Dict(str, ListOf(str))() - self.loadForSymbol(linkName) + def hasSymbol(self, linkName: str) -> bool: + """NB this will return True even if the linkName is ultimately unretrievable.""" + return linkName in self.nameToModuleHash + def getTarget(self, linkName: str) -> TypedCallTarget: + if not self.hasSymbol(linkName): + raise ValueError(f'symbol not found for linkName {linkName}') + self.loadForSymbol(linkName) return self.targetsLoaded[linkName] - def markModuleHashInvalid(self, hashstr): - with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"): - pass + def dependencies(self, linkName: str) -> Optional[List[str]]: + """Returns all the function names that `linkName` depends on""" + return list(self.function_dependency_graph.outgoing(linkName)) - def loadForSymbol(self, linkName): + def loadForSymbol(self, linkName: str) -> None: + """Loads the whole module, and any submodules, into LoadedBinarySharedObjects""" moduleHash = self.nameToModuleHash[linkName] self.loadModuleByHash(moduleHash) - def loadModuleByHash(self, moduleHash): + if linkName not in self.targetsValidated: + dependantFuncs = self.dependencies(linkName) + [linkName] + globalsToLink = {} # dict from modulehash to list of globals. + for funcName in dependantFuncs: + if funcName not in self.targetsValidated: + funcModuleHash = self.nameToModuleHash[funcName] + # append to the list of globals to link for a given module. TODO: optimise this, don't double-link. + globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, []) + + for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too. 
+ if globs: + definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x] + for x in globs + } + self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink) + if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink): + raise RuntimeError('failed to validate globals when loading:', linkName) + + self.targetsValidated.update(dependantFuncs) + + def loadModuleByHash(self, moduleHash: str) -> None: """Load a module by name. As we load, place all the newly imported typed call targets into 'nameToTypedCallTarget' so that the rest of the system knows what functions have been uncovered. """ - if moduleHash in self.loadedModules: - return True + if moduleHash in self.loadedBinarySharedObjects: + return targetDir = os.path.join(self.cacheDir, moduleHash) - try: - with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: - callTargets = SerializationContext().deserialize(f.read()) + # TODO (Will) - store these names as module consts, use one .dat only + with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: + callTargets = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: - globalVarDefs = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: + serializedGlobalVarDefs = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: - functionNameToNativeType = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: + functionNameToNativeType = SerializationContext().deserialize(f.read()) + + with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: + submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: - 
submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - except Exception: - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f: + dependency_edgelist = SerializationContext().deserialize(f.read()) - if not LoadedModule.validateGlobalVariables(globalVarDefs): - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f: + globalDependencies = SerializationContext().deserialize(f.read()) # load the submodules first for submodule in submodules: - if not self.loadModuleByHash(submodule): - return False + self.loadModuleByHash(submodule) modulePath = os.path.join(targetDir, "module.so") loaded = BinarySharedObject.fromDisk( modulePath, - globalVarDefs, - functionNameToNativeType + serializedGlobalVarDefs, + functionNameToNativeType, + globalDependencies + ).loadFromPath(modulePath) - self.loadedModules[moduleHash] = loaded + self.loadedBinarySharedObjects[moduleHash] = loaded self.targetsLoaded.update(callTargets) - return True + assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision. + self.global_dependencies.update(globalDependencies) + + # update the cache's dependency graph with our new edges. + for function_name, dependant_function_name in dependency_edgelist: + self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name) - def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies): + def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist): """Add new code to the compiler cache. 
Args: - binarySharedObject - a BinarySharedObject containing the actual assembler - we've compiled - nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us - the formal python types for all the objects - linkDependencies - a set of linknames we depend on directly. + binarySharedObject: a BinarySharedObject containing the actual assembler + we've compiled. + nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us + the formal python types for all the objects. + linkDependencies: a set of linknames we depend on directly. + dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the + module. """ dependentHashes = set() for name in linkDependencies: dependentHashes.add(self.nameToModuleHash[name]) - path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes) + path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist) - self.loadedModules[hashToUse] = ( + self.loadedBinarySharedObjects[hashToUse] = ( binarySharedObject.loadFromPath(os.path.join(path, "module.so")) ) for n in binarySharedObject.definedSymbols: self.nameToModuleHash[n] = hashToUse - def loadNameManifestFromStoredModuleByHash(self, moduleHash): - if moduleHash in self.modulesMarkedValid: - return True - - targetDir = os.path.join(self.cacheDir, moduleHash) + # link & validate all globals for the new module + self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables() + if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables( + self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions): + raise RuntimeError('failed to validate globals in new module:', hashToUse) - # ignore 'marked invalid' - if os.path.exists(os.path.join(targetDir, "marked_invalid")): - # just bail - don't try to read it now + def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None: + if moduleHash in 
self.moduleManifestsLoaded: + return - # for the moment, we don't try to clean up the cache, because - # we can't be sure that some process is not still reading the - # old files. - self.modulesMarkedInvalid.add(moduleHash) - return False + targetDir = os.path.join(self.cacheDir, moduleHash) with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: submodules = SerializationContext().deserialize(f.read(), ListOf(str)) for subHash in submodules: - if not self.loadNameManifestFromStoredModuleByHash(subHash): - self.markModuleHashInvalid(subHash) - return False + self.loadNameManifestFromStoredModuleByHash(subHash) with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f: self.nameToModuleHash.update( SerializationContext().deserialize(f.read(), Dict(str, str)) ) - self.modulesMarkedValid.add(moduleHash) - - return True + self.moduleManifestsLoaded.add(moduleHash) - def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules): + def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist): """Write out a disk representation of this module. 
This includes writing both the shared object, a manifest of the function names @@ -244,11 +274,17 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule # write the type manifest with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f: - f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions)) + f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions)) with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f: f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str))) + with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof + + with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.globalDependencies)) + try: os.rename(tempTargetDir, targetDir) except IOError: @@ -264,7 +300,7 @@ def function_pointer_by_name(self, linkName): if moduleHash is None: raise Exception("Can't find a module for " + linkName) - if moduleHash not in self.loadedModules: + if moduleHash not in self.loadedBinarySharedObjects: self.loadForSymbol(linkName) - return self.loadedModules[moduleHash].functionPointers[linkName] + return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName] diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py index 81ad2f12f..4d35fbc33 100644 --- a/typed_python/compiler/compiler_cache_test.py +++ b/typed_python/compiler/compiler_cache_test.py @@ -119,6 +119,7 @@ def test_compiler_cache_understands_type_changes(): VERSION1 = {'x.py': xmodule, 'y.py': ymodule} VERSION2 = {'x.py': xmodule.replace("1: 2", "1: 3"), 'y.py': ymodule} VERSION3 = {'x.py': xmodule.replace("int, int", "int, float").replace('1: 2', '1: 2.5'), 'y.py': ymodule} + VERSION4 = {'x.py': 
xmodule.replace("1: 2", "1: 4"), 'y.py': ymodule} assert '1: 3' in VERSION2['x.py'] @@ -134,6 +135,10 @@ def test_compiler_cache_understands_type_changes(): assert evaluateExprInFreshProcess(VERSION3, 'y.g(1)', compilerCacheDir) == 2.5 assert len(os.listdir(compilerCacheDir)) == 2 + # use the previously compiled module + assert evaluateExprInFreshProcess(VERSION4, 'y.g(1)', compilerCacheDir) == 4 + assert len(os.listdir(compilerCacheDir)) == 2 + @pytest.mark.skipif('sys.platform=="darwin"') def test_compiler_cache_handles_exceptions_properly(): @@ -362,12 +367,9 @@ def test_compiler_cache_handles_changed_types(): assert evaluateExprInFreshProcess(VERSION2, 'x.f(1)', compilerCacheDir) == 1 assert len(os.listdir(compilerCacheDir)) == 2 - badCt = 0 - for subdir in os.listdir(compilerCacheDir): - if 'marked_invalid' in os.listdir(os.path.join(compilerCacheDir, subdir)): - badCt += 1 - - assert badCt == 1 + # if we then use g1 again, it should not have been marked invalid and so remains accessible. 
+ assert evaluateExprInFreshProcess(VERSION1, 'x.g1(1)', compilerCacheDir) == 1 + assert len(os.listdir(compilerCacheDir)) == 2 @pytest.mark.skipif('sys.platform=="darwin"') @@ -395,3 +397,34 @@ def test_ordering_is_stable_under_code_change(): ) assert names == names2 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_compiler_cache_avoids_deserialization_error(): + xmodule1 = "\n".join([ + "@Entrypoint", + "def f():", + " return None", + "import badModule", + "@Entrypoint", + "def g():", + " print(badModule)", + " return f()", + ]) + + xmodule2 = "\n".join([ + "@Entrypoint", + "def f():", + " return", + ]) + + VERSION1 = {'x.py': xmodule1, 'badModule.py': ''} + VERSION2 = {'x.py': xmodule2} + + with tempfile.TemporaryDirectory() as compilerCacheDir: + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 1 + evaluateExprInFreshProcess(VERSION2, 'x.f()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 diff --git a/typed_python/compiler/global_variable_definition.py b/typed_python/compiler/global_variable_definition.py index 4f01c11f0..b95dca82b 100644 --- a/typed_python/compiler/global_variable_definition.py +++ b/typed_python/compiler/global_variable_definition.py @@ -79,3 +79,13 @@ def __init__(self, name, typ, metadata): self.name = name self.type = typ self.metadata = metadata + + def __eq__(self, other): + if not isinstance(other, GlobalVariableDefinition): + return False + + return self.name == other.name and self.type == other.type and self.metadata == other.metadata + + def __str__(self): + metadata_str = str(self.metadata) if len(str(self.metadata)) < 100 else str(self.metadata)[:100] + "..." 
+ return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={metadata_str})" diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index 4e019c032..f33e5edb3 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -22,7 +22,7 @@ from typed_python.compiler.binary_shared_object import BinarySharedObject import ctypes -from typed_python import _types +from typed_python import _types, SerializationContext llvm.initialize() llvm.initialize_native_target() @@ -117,17 +117,20 @@ def buildSharedObject(self, functions): self.engine.finalize_object() + serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()} + return BinarySharedObject.fromModule( mod, - module.globalVariableDefinitions, + serializedGlobalVariableDefinitions, module.functionNameToType, + module.globalDependencies ) def function_pointer_by_name(self, name): return self.functions_by_name.get(name) def buildModule(self, functions): - """Compile a list of functions into a new module. + """Compile a list of functions into a new module. Only relevant if there is no compiler cache. 
Args: functions - a map from name to native_ast.Function @@ -183,4 +186,5 @@ def buildModule(self, functions): ) ) - return LoadedModule(native_function_pointers, module.globalVariableDefinitions) + serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()} + return LoadedModule(native_function_pointers, serializedGlobalVariableDefinitions) diff --git a/typed_python/compiler/llvm_compiler_test.py b/typed_python/compiler/llvm_compiler_test.py index e10f94533..fb1911e00 100644 --- a/typed_python/compiler/llvm_compiler_test.py +++ b/typed_python/compiler/llvm_compiler_test.py @@ -115,7 +115,7 @@ def test_create_binary_shared_object(): {'__test_f_2': f} ) - assert len(bso.globalVariableDefinitions) == 1 + assert len(bso.serializedGlobalVariableDefinitions) == 1 with tempfile.TemporaryDirectory() as tf: loaded = bso.load(tf) diff --git a/typed_python/compiler/loaded_module.py b/typed_python/compiler/loaded_module.py index c03ab3211..01b9ec814 100644 --- a/typed_python/compiler/loaded_module.py +++ b/typed_python/compiler/loaded_module.py @@ -1,39 +1,46 @@ +from typing import Dict, List from typed_python.compiler.module_definition import ModuleDefinition -from typed_python import PointerTo, ListOf, Class +from typed_python import PointerTo, ListOf, Class, SerializationContext from typed_python import _types class LoadedModule: """Represents a bundle of compiled functions that are now loaded in memory. 
- Members: functionPointers - a map from name to NativeFunctionPointer giving the public interface of the module - globalVariableDefinitions - a map from name to GlobalVariableDefinition + serializedGlobalVariableDefinitions - a map from LLVM-assigned global name to serialized GlobalVariableDefinition giving the loadable strings """ GET_GLOBAL_VARIABLES_NAME = ModuleDefinition.GET_GLOBAL_VARIABLES_NAME - def __init__(self, functionPointers, globalVariableDefinitions): + def __init__(self, functionPointers, serializedGlobalVariableDefinitions): self.functionPointers = functionPointers + assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers + + self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions + self.orderedDefs = [ + self.serializedGlobalVariableDefinitions[name] for name in sorted(self.serializedGlobalVariableDefinitions) + ] + self.orderedDefNames = sorted(list(self.serializedGlobalVariableDefinitions.keys())) + self.pointers = ListOf(PointerTo(int))() + self.pointers.resize(len(self.orderedDefs)) - self.globalVariableDefinitions = globalVariableDefinitions + self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0)) @staticmethod - def validateGlobalVariables(globalVariableDefinitions): + def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool: """Check that each global variable definition is sensible. - Sometimes we may successfully deserialize a global variable from a cached module, but then some dictionary member is not valid because it was removed or has the wrong type. In this case, we need to evict this module from the cache because it's no longer valid. 
Args: - globalVariableDefinitions - a dict from string to GlobalVariableMetadata + serializedGlobalVariableDefinitions: a dict from string to a serialized GlobalVariableMetadata """ - for gvd in globalVariableDefinitions.values(): - meta = gvd.metadata - + for gvd in serializedGlobalVariableDefinitions.values(): + meta = SerializationContext().deserialize(gvd).metadata if meta.matches.PointerToTypedPythonObjectAsMemberOfDict: if not isinstance(meta.sourceDict, dict): return False @@ -54,54 +61,45 @@ def validateGlobalVariables(globalVariableDefinitions): return True - def linkGlobalVariables(self): - """Walk over all global variables in the module and make sure they are populated. - + def linkGlobalVariables(self, variable_names: List[str] = None) -> None: + """Walk over all global variables in `variable_names` and make sure they are populated. Each module has a bunch of global variables that contain references to things like type objects, string objects, python module members, etc. - - The metadata about these is stored in 'self.globalVariableDefinitions' whose keys + The metadata about these is stored in 'self.serializedGlobalVariableDefinitions' whose keys are names and whose values are GlobalVariableMetadata instances. - Every module we compile exposes a member named ModuleDefinition.GET_GLOBAL_VARIABLES_NAME which takes a pointer to a list of pointers and fills it out with the global variables. - When the module is loaded, all the variables are initialized to zero. This function walks over them and populates them, effectively linking them into the current binary. 
""" - assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers - - orderedDefs = [ - self.globalVariableDefinitions[name] for name in sorted(self.globalVariableDefinitions) - ] - - pointers = ListOf(PointerTo(int))() - pointers.resize(len(orderedDefs)) - self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](pointers.pointerUnsafe(0)) + if variable_names is None: + i_vals = range(len(self.orderedDefs)) + else: + i_vals = [self.orderedDefNames.index(x) for x in variable_names] - for i in range(len(orderedDefs)): - assert pointers[i], f"Failed to get a pointer to {orderedDefs[i].name}" + for i in i_vals: + assert self.pointers[i], f"Failed to get a pointer to {self.orderedDefs[i].name}" - meta = orderedDefs[i].metadata + meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata if meta.matches.StringConstant: - pointers[i].cast(str).initialize(meta.value) + self.pointers[i].cast(str).initialize(meta.value) if meta.matches.IntegerConstant: - pointers[i].cast(int).initialize(meta.value) + self.pointers[i].cast(int).initialize(meta.value) elif meta.matches.BytesConstant: - pointers[i].cast(bytes).initialize(meta.value) + self.pointers[i].cast(bytes).initialize(meta.value) elif meta.matches.PointerToPyObject: - pointers[i].cast(object).initialize(meta.value) + self.pointers[i].cast(object).initialize(meta.value) elif meta.matches.PointerToTypedPythonObject: - pointers[i].cast(meta.type).initialize(meta.value) + self.pointers[i].cast(meta.type).initialize(meta.value) elif meta.matches.PointerToTypedPythonObjectAsMemberOfDict: - pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) + self.pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) elif meta.matches.ClassMethodDispatchSlot: slotIx = _types.allocateClassMethodDispatch( @@ -111,17 +109,17 @@ def linkGlobalVariables(self): meta.argTupleType, meta.kwargTupleType ) - pointers[i].cast(int).initialize(slotIx) + 
self.pointers[i].cast(int).initialize(slotIx) elif meta.matches.IdOfPyObject: - pointers[i].cast(int).initialize(id(meta.value)) + self.pointers[i].cast(int).initialize(id(meta.value)) elif meta.matches.ClassVtable: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types._vtablePointer(meta.value) ) elif meta.matches.RawTypePointer: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types.getTypePointer(meta.value) ) diff --git a/typed_python/compiler/module_definition.py b/typed_python/compiler/module_definition.py index 4e9b35fde..cadbb2ec6 100644 --- a/typed_python/compiler/module_definition.py +++ b/typed_python/compiler/module_definition.py @@ -18,15 +18,19 @@ class ModuleDefinition: """A single module of compiled llvm code. - Members: - moduleText - a string containing the llvm IR for the module - functionList - a list of the names of exported functions - globalDefinitions - a dict from name to a GlobalDefinition + Attributes: + moduleText (str): a string containing the llvm IR for the module + functionList (list): a list of the names of exported functions + globalDefinitions (dict): a dict from name to a GlobalDefinition + globalDependencies (dict): a dict from function link_name to a list of globals the + function depends on + hash (str): The module hash, generated from the llvm IR. 
""" GET_GLOBAL_VARIABLES_NAME = ".get_global_variables" - def __init__(self, moduleText, functionNameToType, globalVariableDefinitions): + def __init__(self, moduleText, functionNameToType, globalVariableDefinitions, globalDependencies): self.moduleText = moduleText self.functionNameToType = functionNameToType self.globalVariableDefinitions = globalVariableDefinitions + self.globalDependencies = globalDependencies self.hash = sha_hash(moduleText) diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index 769b58f62..bbd2027d8 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typed_python.compiler.native_ast as native_ast -from typed_python.compiler.module_definition import ModuleDefinition -from typed_python.compiler.global_variable_definition import GlobalVariableDefinition import llvmlite.ir import os - +import typed_python.compiler.native_ast as native_ast +from typed_python.compiler.global_variable_definition import GlobalVariableDefinition +from typed_python.compiler.module_definition import ModuleDefinition +from typing import Dict llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer() llvm_i8 = llvmlite.ir.IntType(8) llvm_i32 = llvmlite.ir.IntType(32) @@ -512,7 +512,8 @@ def __init__(self, # dict from name to GlobalVariableDefinition self.globalDefinitions = globalDefinitions self.globalDefinitionLlvmValues = globalDefinitionLlvmValues - + # a list of the global LLVM names that the function depends on. 
+ self.global_names = [] self.module = module self.converter = converter self.builder = builder @@ -631,7 +632,16 @@ def generate_exception_and_store_value(self, llvm_pointer_val): ) return self.builder.bitcast(exception_ptr, llvm_i8ptr) - def namedCallTargetToLLVM(self, target): + def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVMValue: + """ + Generate llvm IR code for a given target. + + There are three options for code generation: + 1. The target is external, i.e something like pyobj_len, np_add_traceback - system-level functions. We add to + external_function_references. + 2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision. + 3. We have a compiler cache, and the function is in it. We add to external_function_references. + """ if target.external: if target.name not in self.external_function_references: func_type = llvmlite.ir.FunctionType( @@ -650,7 +660,6 @@ def namedCallTargetToLLVM(self, target): func = self.external_function_references[target.name] elif target.name in self.converter._function_definitions: func = self.converter._functions_by_name[target.name] - if func.module is not self.module: # first, see if we'd like to inline this module if ( @@ -664,7 +673,7 @@ def namedCallTargetToLLVM(self, target): func = self.external_function_references[target.name] else: - # TODO: decide whether to inline based on something in the compiler cache + # TODO (Will): decide whether to inline cached code assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is defined in a shared object that we've loaded from a prior # invocation @@ -802,6 +811,7 @@ def _convert(self, expr): return self.stack_slots[expr.name] if expr.matches.GlobalVariable: + self.global_names.append(expr.name) if expr.name in self.globalDefinitions: assert expr.metadata == self.globalDefinitions[expr.name].metadata, ( expr.metadata, 
self.globalDefinitions[expr.name].metadata @@ -1485,11 +1495,11 @@ def define(fname, output, inputs, vararg=False): class Converter: - def __init__(self, compilerCache): + def __init__(self, compilerCache=None): object.__init__(self) self._modules = {} - self._functions_by_name = {} - self._function_definitions = {} + self._functions_by_name: Dict[str, llvmlite.ir.Function] = {} + self._function_definitions: Dict[str, native_ast.Function] = {} # total number of instructions in each function, by name self._function_complexity = {} @@ -1504,7 +1514,7 @@ def __init__(self, compilerCache): def totalFunctionComplexity(self, name): """Return the total number of instructions contained in a function. - The function must already have been defined in a prior parss. We use this + The function must already have been defined in a prior pass. We use this information to decide which functions to repeat in new module definitions. """ if name in self._function_complexity: @@ -1538,9 +1548,7 @@ def repeatFunctionInModule(self, name, module): assert isinstance(funcType, llvmlite.ir.FunctionType) self._functions_by_name[name] = llvmlite.ir.Function(module, funcType, name) - self._inlineRequests.append(name) - return self._functions_by_name[name] def add_functions(self, names_to_definitions): @@ -1596,7 +1604,8 @@ def add_functions(self, names_to_definitions): globalDefinitions = {} globalDefinitionsLlvmValues = {} - + # we need a separate dictionary owing to the possibility of global var reuse across functions. 
+ globalDependencies = {} while names_to_definitions: for name in sorted(names_to_definitions): definition = names_to_definitions.pop(name) @@ -1613,7 +1622,7 @@ def add_functions(self, names_to_definitions): TypedLLVMValue(func.args[i], definition.args[i][1]) block = func.append_basic_block('entry') - builder = llvmlite.ir.IRBuilder(block) # this shares state with func + builder = llvmlite.ir.IRBuilder(block) try: func_converter = FunctionConverter( @@ -1634,6 +1643,8 @@ def add_functions(self, names_to_definitions): func_converter.finalize() + globalDependencies[func.name] = func_converter.global_names + if res is not None: assert res.llvm_value is None if definition.output_type != native_ast.Void: @@ -1667,7 +1678,8 @@ def add_functions(self, names_to_definitions): return ModuleDefinition( str(module), functionTypes, - globalDefinitions + globalDefinitions, + globalDependencies ) def defineGlobalMetadataAccessor(self, module, globalDefinitions, globalDefinitionsLlvmValues): diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index 6c23a92a8..eb9126475 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -17,6 +17,7 @@ from typed_python.hash import Hash from types import ModuleType +from typing import Dict from typed_python import Class import typed_python.python_ast as python_ast import typed_python._types as _types @@ -129,10 +130,10 @@ def __init__(self, llvmCompiler, compilerCache): self._link_name_for_identity = {} self._identity_for_link_name = {} - self._definitions = {} - self._targets = {} + self._definitions: Dict[str, native_ast.Function] = {} + self._targets: Dict[str, TypedCallTarget] = {} self._inflight_definitions = {} - self._inflight_function_conversions = {} + self._inflight_function_conversions: Dict[str, FunctionConversionContext] = {} self._identifier_to_pyfunc = {} self._times_calculated = {} @@ -194,12 
+195,14 @@ def buildAndLinkNewModule(self): # get a set of function names that we depend on externallyUsed = set() + dependency_edgelist = [] for funcName in targets: ident = self._identity_for_link_name.get(funcName) if ident is not None: for dep in self._dependencies.getNamesDependedOn(ident): depLN = self._link_name_for_identity.get(dep) + dependency_edgelist.append([funcName, depLN]) if depLN not in targets: externallyUsed.add(depLN) @@ -208,7 +211,8 @@ def buildAndLinkNewModule(self): self.compilerCache.addModule( binary, {name: self.getTarget(name) for name in targets if self.hasTarget(name)}, - externallyUsed + externallyUsed, + dependency_edgelist, ) def extract_new_function_definitions(self): @@ -277,13 +281,14 @@ def deleteTarget(self, linkName): self._targets.pop(linkName) def setTarget(self, linkName, target): + assert(isinstance(target, TypedCallTarget)) self._targets[linkName] = target def getTarget(self, linkName): if linkName in self._targets: return self._targets[linkName] - if self.compilerCache.hasSymbol(linkName): + if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName): return self.compilerCache.getTarget(linkName) return None @@ -293,7 +298,7 @@ def defineNonPythonFunction(self, name, identityTuple, context): name - the name to actually give the function. 
identityTuple - a unique (sha)hashable tuple - context - a FunctionConvertsionContext lookalike + context - a FunctionConversionContext lookalike returns a TypedCallTarget, or None if it's not known yet """ @@ -329,7 +334,7 @@ def defineNonPythonFunction(self, name, identityTuple, context): if self._currentlyConverting is None: # force the function to resolve immediately self._resolveAllInflightFunctions() - self._installInflightFunctions(name) + self._installInflightFunctions() self._inflight_function_conversions.clear() return self.getTarget(linkName) @@ -460,7 +465,7 @@ def generateCallConverter(self, callTarget: TypedCallTarget): return linkName # # we already defined it in another process so don't do it again - if self.compilerCache.hasSymbol(linkName): + if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName): return linkName # N.B. there aren't targets for call converters. We make the definition directly. @@ -883,13 +888,12 @@ def convert( output_type, conversionType ) - self._inflight_function_conversions[identity] = functionConverter if isRoot: try: self._resolveAllInflightFunctions() - self._installInflightFunctions(name) + self._installInflightFunctions() return self.getTarget(name) finally: self._inflight_function_conversions.clear() @@ -901,10 +905,11 @@ def convert( # return None, which will cause callers to replace this with a throw # until we have had a chance to do a full pass of conversion. 
if self.getTarget(name) is not None: - raise Exception("This code looks unreachable to me...") + raise RuntimeError(f"Unexpected conversion error for {name}") return None - def _installInflightFunctions(self, name): + def _installInflightFunctions(self): + """Add all function definitions corresponding to keys in inflight_function_conversions to the relevant dictionaries.""" if VALIDATE_FUNCTION_DEFINITIONS_STABLE: # this should always be true, but its expensive so we have it off by default for identifier, functionConverter in self._inflight_function_conversions.items(): From 22f01c539382a4100928192082b28c1a87ab7d6f Mon Sep 17 00:00:00 2001 From: William Grant Date: Thu, 19 Jan 2023 19:15:55 -0500 Subject: [PATCH 3/7] Ensure that serializer can use the 'name' pickle protocol. Pickle supports a protocol where __reduce__ returns a string giving the global name. Implementing this behaviour lets us serialize numpy ufuncs. Also adjust installInflightFunctions to handle new load behaviour, fix an instability caused by not leaving LoadedModule objects in memory, and adjust alternative test.
--- typed_python/SerializationContext.py | 44 ++++++++++++++----- .../compiler/global_variable_definition.py | 3 +- typed_python/compiler/llvm_compiler_test.py | 27 ++++++++++++ typed_python/compiler/loaded_module.py | 4 ++ .../compiler/python_to_native_converter.py | 13 +++++- .../compiler/tests/numpy_interaction_test.py | 11 ++++- .../type_of_instances_compilation_test.py | 6 +-- typed_python/types_serialization_test.py | 24 ++++++++++ 8 files changed, 113 insertions(+), 19 deletions(-) diff --git a/typed_python/SerializationContext.py b/typed_python/SerializationContext.py index b110be902..5b449d709 100644 --- a/typed_python/SerializationContext.py +++ b/typed_python/SerializationContext.py @@ -31,11 +31,24 @@ import types import traceback import logging +import numpy +import pickle _badModuleCache = set() +def pickledByStr(module_name: str, name: str): + """Generate and return the object given the module_name and name. + + This mimics pickle's behavior when given a string from __reduce__. The + string is interpreted as the name of a global variable, and pickle.whichmodule + is used to search the module namespace, generating module_name. + """ + module = importlib.import_module(module_name) + return getattr(module, name) + + def createFunctionWithLocalsAndGlobals(code, globals): if globals is None: globals = {} @@ -708,26 +721,30 @@ def walkCodeObject(code): return (createFunctionWithLocalsAndGlobals, args, representation) if not isinstance(inst, type) and hasattr(type(inst), '__reduce_ex__'): - res = inst.__reduce_ex__(4) + if isinstance(inst, numpy.ufunc): + res = inst.__name__ + else: + res = inst.__reduce_ex__(4) - # pickle supports a protocol where __reduce__ can return a string - # giving a global name. We'll already find that separately, so we - # don't want to handle it here. We ought to look at this in more detail - # however + # mimic pickle's behaviour when a string is received.
if isinstance(res, str): - return None + name_tuple = (inst, res) + module_name = pickle.whichmodule(*name_tuple) + res = (pickledByStr, (module_name, res,), pickledByStr) return res if not isinstance(inst, type) and hasattr(type(inst), '__reduce__'): - res = inst.__reduce__() + if isinstance(inst, numpy.ufunc): + res = inst.__name__ + else: + res = inst.__reduce__() - # pickle supports a protocol where __reduce__ can return a string - # giving a global name. We'll already find that separately, so we - # don't want to handle it here. We ought to look at this in more detail - # however + # mimic pickle's behaviour when a string is received. if isinstance(res, str): - return None + name_tuple = (inst, res) + module_name = pickle.whichmodule(*name_tuple) + res = (pickledByStr, (module_name, res,), pickledByStr) return res @@ -736,6 +753,9 @@ def walkCodeObject(code): def setInstanceStateFromRepresentation( self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None ): + if representation is pickledByStr: + return + if representation is reconstructTypeFunctionType: return diff --git a/typed_python/compiler/global_variable_definition.py b/typed_python/compiler/global_variable_definition.py index b95dca82b..4dbf34f8b 100644 --- a/typed_python/compiler/global_variable_definition.py +++ b/typed_python/compiler/global_variable_definition.py @@ -87,5 +87,4 @@ def __eq__(self, other): return self.name == other.name and self.type == other.type and self.metadata == other.metadata def __str__(self): - metadata_str = str(self.metadata) if len(str(self.metadata)) < 100 else str(self.metadata)[:100] + "..."
- return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={metadata_str})" + return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={pad(str(self.metadata))})" diff --git a/typed_python/compiler/llvm_compiler_test.py b/typed_python/compiler/llvm_compiler_test.py index fb1911e00..d914bae4f 100644 --- a/typed_python/compiler/llvm_compiler_test.py +++ b/typed_python/compiler/llvm_compiler_test.py @@ -20,6 +20,8 @@ from typed_python.compiler.module_definition import ModuleDefinition from typed_python.compiler.global_variable_definition import GlobalVariableMetadata +from typed_python.test_util import evaluateExprInFreshProcess + import pytest import ctypes @@ -131,3 +133,28 @@ def test_create_binary_shared_object(): pointers[0].set(5) assert loaded.functionPointers['__test_f_2']() == 5 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_loaded_modules_persist(): + """ + Make sure that loaded modules are persisted in the converter state. + + We have to maintain these references to avoid surprise segfaults - if this test fails, + it should be because the GlobalVariableDefinition memory management has been refactored. 
+ """ + + # compile a module + xmodule = "\n".join([ + "@Entrypoint", + "def f(x):", + " return x + 1", + "@Entrypoint", + "def g(x):", + " return f(x) * 100", + "g(1000)", + "def get_loaded_modules():", + " return len(Runtime.singleton().converter.loadedUncachedModules)" + ]) + VERSION1 = {'x.py': xmodule} + assert evaluateExprInFreshProcess(VERSION1, 'x.get_loaded_modules()') == 1 diff --git a/typed_python/compiler/loaded_module.py b/typed_python/compiler/loaded_module.py index 01b9ec814..ffb2112c0 100644 --- a/typed_python/compiler/loaded_module.py +++ b/typed_python/compiler/loaded_module.py @@ -28,6 +28,8 @@ def __init__(self, functionPointers, serializedGlobalVariableDefinitions): self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0)) + self.installedGlobalVariableDefinitions = {} + @staticmethod def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool: """Check that each global variable definition is sensible. @@ -83,6 +85,8 @@ def linkGlobalVariables(self, variable_names: List[str] = None) -> None: meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata + self.installedGlobalVariableDefinitions[i] = meta + if meta.matches.StringConstant: self.pointers[i].cast(str).initialize(meta.value) diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index eb9126475..473c603a2 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -125,6 +125,12 @@ def __init__(self, llvmCompiler, compilerCache): self.llvmCompiler = llvmCompiler self.compilerCache = compilerCache + # all LoadedModule objects that we have created. We need to keep them alive so + # that any python metadata objects the've created stay alive as well. 
Ultimately, this + # may not be the place we put these objects (for instance, you could imagine a + # 'dummy' compiler cache or something). But for now, we need to keep them alive. + self.loadedUncachedModules = [] + # if True, then insert additional code to check for undefined behavior. self.generateDebugChecks = False @@ -191,6 +197,7 @@ def buildAndLinkNewModule(self): if self.compilerCache is None: loadedModule = self.llvmCompiler.buildModule(targets) loadedModule.linkGlobalVariables() + self.loadedUncachedModules.append(loadedModule) return # get a set of function names that we depend on @@ -926,7 +933,11 @@ def _installInflightFunctions(self): outboundTargets = [] for outboundFuncId in self._dependencies.getNamesDependedOn(identifier): name = self._link_name_for_identity[outboundFuncId] - outboundTargets.append(self._targets[name]) + target = self.getTarget(name) + if target is not None: + outboundTargets.append(target) + else: + raise RuntimeError(f'dependency not found for {name}.') nativeFunction, actual_output_type = self._inflight_definitions.get(identifier) diff --git a/typed_python/compiler/tests/numpy_interaction_test.py b/typed_python/compiler/tests/numpy_interaction_test.py index f15bfea9d..db774c6d2 100644 --- a/typed_python/compiler/tests/numpy_interaction_test.py +++ b/typed_python/compiler/tests/numpy_interaction_test.py @@ -1,4 +1,4 @@ -from typed_python import ListOf, Entrypoint +from typed_python import ListOf, Entrypoint, SerializationContext import numpy import numpy.linalg @@ -44,3 +44,12 @@ def test_listof_from_sliced_numpy_array(): y = x[::2] assert ListOf(int)(y) == [0, 2] + + +def test_can_serialize_numpy_ufunc(): + assert numpy.sin == SerializationContext().deserialize(SerializationContext().serialize(numpy.sin)) + + +def test_can_serialize_numpy_array(): + x = numpy.ones(10) + assert (x == SerializationContext().deserialize(SerializationContext().serialize(x))).all() diff --git 
a/typed_python/compiler/tests/type_of_instances_compilation_test.py b/typed_python/compiler/tests/type_of_instances_compilation_test.py index 337bdc3ff..c3fdf459c 100644 --- a/typed_python/compiler/tests/type_of_instances_compilation_test.py +++ b/typed_python/compiler/tests/type_of_instances_compilation_test.py @@ -17,13 +17,13 @@ def typeOfArg(x: C): def test_type_of_alternative_is_specific(): for members in [{}, {'a': int}]: - A = Alternative("A", A=members) + Alt = Alternative("Alt", A=members) @Entrypoint - def typeOfArg(x: A): + def typeOfArg(x: Alt): return type(x) - assert typeOfArg(A.A()) is A.A + assert typeOfArg(Alt.A()) is Alt.A def test_type_of_concrete_alternative_is_specific(): diff --git a/typed_python/types_serialization_test.py b/typed_python/types_serialization_test.py index c5f33c2a1..fba4b0ee6 100644 --- a/typed_python/types_serialization_test.py +++ b/typed_python/types_serialization_test.py @@ -3061,3 +3061,27 @@ def f(self): print(x) # TODO: make this True # assert x[0].f.__closure__[0].cell_contents is x + + def test_serialize_pyobj_with_custom_reduce(self): + class CustomReduceObject: + def __reduce__(self): + return 'CustomReduceObject' + + assert CustomReduceObject == SerializationContext().deserialize(SerializationContext().serialize(CustomReduceObject)) + + def test_serialize_pyobj_in_MRTG_with_custom_reduce(self): + def getX(): + class InnerCustomReduceObject: + def __reduce__(self): + return 'InnerCustomReduceObject' + + def f(self): + return x + + x = (InnerCustomReduceObject, InnerCustomReduceObject) + + return x + + x = callFunctionInFreshProcess(getX, (), showStdout=True) + + assert x == SerializationContext().deserialize(SerializationContext().serialize(x)) From 588610a01d36d5b5cb5457c7d4482190a69f5b39 Mon Sep 17 00:00:00 2001 From: Aaron Levy Date: Fri, 10 Feb 2023 14:57:02 +0000 Subject: [PATCH 4/7] Fix quarter attribute of Date --- typed_python/lib/datetime/date_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/typed_python/lib/datetime/date_time.py b/typed_python/lib/datetime/date_time.py index 634a624c7..5b189b6f8 100644 --- a/typed_python/lib/datetime/date_time.py +++ b/typed_python/lib/datetime/date_time.py @@ -269,7 +269,7 @@ def firstOfMonth(self): return Date(self.year, self.month, 1) def quarterOfYear(self): - return (self.date.month - 1) // 3 + 1 + return (self.month - 1) // 3 + 1 def next(self, step: int = 1): """Returns the date `step` days ahead of `self`. From 829fbe41acec0a2bdafc7bd05d43a2d002f9a6ca Mon Sep 17 00:00:00 2001 From: Aaron Levy Date: Wed, 15 Feb 2023 23:05:03 +0000 Subject: [PATCH 5/7] Remove date_time stuff from TP __init__.py --- typed_python/__init__.py | 2 -- typed_python/lib/datetime/date_time_test.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/typed_python/__init__.py b/typed_python/__init__.py index 6c4d70706..ca8495770 100644 --- a/typed_python/__init__.py +++ b/typed_python/__init__.py @@ -79,8 +79,6 @@ from typed_python.lib.map import map # noqa from typed_python.lib.pmap import pmap # noqa from typed_python.lib.reduce import reduce # noqa -from typed_python.lib.timestamp import Timestamp # noqa -from typed_python.lib.datetime.date_time import UTC, NYC, TimeOfDay, DateTime, Date, PytzTimezone # noqa _types.initializeGlobalStatics() diff --git a/typed_python/lib/datetime/date_time_test.py b/typed_python/lib/datetime/date_time_test.py index 1e7e3f1be..1f14deb99 100644 --- a/typed_python/lib/datetime/date_time_test.py +++ b/typed_python/lib/datetime/date_time_test.py @@ -1,6 +1,7 @@ import pytest import pytz import datetime +from typed_python import NamedTuple from typed_python.lib.datetime.date_time import ( Date, DateTime, @@ -15,8 +16,8 @@ OneFoldOnlyError, PytzTimezone, ) -from typed_python import Timestamp, NamedTuple from typed_python.lib.datetime.date_parser_test import get_datetimes_in_range +from typed_python.lib.timestamp import Timestamp def test_last_weekday_of_month(): From 
e78ec6720c067ad92d80d2fddbde2748e5fe7eca Mon Sep 17 00:00:00 2001 From: Braxton Mckee Date: Thu, 16 Feb 2023 19:00:58 +0000 Subject: [PATCH 6/7] Ensure that the new pickle protocol support works with 'local names' (e.g. dotted method names). --- typed_python/SerializationContext.py | 9 ++++++++- typed_python/types_serialization_test.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/typed_python/SerializationContext.py b/typed_python/SerializationContext.py index 5b449d709..832bbb415 100644 --- a/typed_python/SerializationContext.py +++ b/typed_python/SerializationContext.py @@ -44,9 +44,16 @@ def pickledByStr(module_name: str, name: str) -> None: This mimics pickle's behavior when given a string from __reduce__. The string is interpreted as the name of a global variable, and pickle.whichmodules is used to search the module namespace, generating module_name. + + Note that 'name' might contain '.' inside of it, since its a 'local name'. """ module = importlib.import_module(module_name) - return getattr(module, name) + + instance = module + for subName in name.split('.'): + instance = getattr(instance, subName) + + return instance def createFunctionWithLocalsAndGlobals(code, globals): diff --git a/typed_python/types_serialization_test.py b/typed_python/types_serialization_test.py index fba4b0ee6..e34da54b5 100644 --- a/typed_python/types_serialization_test.py +++ b/typed_python/types_serialization_test.py @@ -15,6 +15,8 @@ import sys import os import importlib +from functools import lru_cache + from abc import ABC, abstractmethod, ABCMeta from typed_python.test_util import callFunctionInFreshProcess import typed_python.compiler.python_ast_util as python_ast_util @@ -57,6 +59,13 @@ module_level_testfun = dummy_test_module.testfunction +class GlobalClassWithLruCache: + @staticmethod + @lru_cache(maxsize=None) + def f(x): + return x + + def moduleLevelFunctionUsedByExactlyOneSerializationTest(): return "please don't touch me" @@ 
-3085,3 +3094,10 @@ def f(self): x = callFunctionInFreshProcess(getX, (), showStdout=True) assert x == SerializationContext().deserialize(SerializationContext().serialize(x)) + + def test_serialize_class_static_lru_cache(self): + s = SerializationContext() + + assert ( + s.deserialize(s.serialize(GlobalClassWithLruCache.f)) is GlobalClassWithLruCache.f + ) From 4843289cead0aecd8edd70d685ca85753405ea6e Mon Sep 17 00:00:00 2001 From: Braxton Mckee Date: Wed, 1 Feb 2023 08:33:02 +0000 Subject: [PATCH 7/7] Prune the graph of inflight functions to not include the ones we don't need. When we first call a python function 'f' with a specific set of arguments, we may not know its return type the first time we try to convert it. To ensure we have a stable typing graph, we repeatedly update the active functions in our graph until the type graph is stable. This can lead to many copies of the same function, or even multiple signatures of the same function, only one of which we'll use. This change prunes those away before we submit them to the LLVM layer. 
--- typed_python/compiler/directed_graph.py | 4 ++ .../compiler/python_to_native_converter.py | 42 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/typed_python/compiler/directed_graph.py b/typed_python/compiler/directed_graph.py index 5cf2f03ea..7cc89b63a 100644 --- a/typed_python/compiler/directed_graph.py +++ b/typed_python/compiler/directed_graph.py @@ -44,6 +44,10 @@ def hasEdge(self, source, dest): return False return dest in self.sourceToDest[source] + def clearOutgoing(self, node): + for child in list(self.outgoing(node)): + self.dropEdge(node, child) + def outgoing(self, node): return self.sourceToDest.get(node, set()) diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index 473c603a2..237fdb1dc 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -58,6 +58,26 @@ def __init__(self): # (priority, node) pairs that need to recompute self._dirty_inflight_functions_with_order = SortedSet(key=lambda pair: pair[0]) + def reachableInSet(self, rootIdentity, activeSet): + """Produce the subset of 'activeSet' that are reachable from 'rootIdentity'""" + reachable = set() + + def walk(node): + if node in reachable or node not in activeSet: + return + + reachable.add(node) + + for child in self._dependencies.outgoing(node): + walk(child) + + walk(rootIdentity) + + return reachable + + def clearOutgoingEdgesFor(self, identity): + self._dependencies.clearOutgoing(identity) + def dropNode(self, node): self._dependencies.dropNode(node, False) if node in self._identity_levels: @@ -341,8 +361,7 @@ def defineNonPythonFunction(self, name, identityTuple, context): if self._currentlyConverting is None: # force the function to resolve immediately self._resolveAllInflightFunctions() - self._installInflightFunctions() - self._inflight_function_conversions.clear() + self._installInflightFunctions(identity) 
return self.getTarget(linkName) @@ -540,6 +559,8 @@ def _resolveAllInflightFunctions(self): # this calls back into convert with dependencies # they get registered as dirty + self._dependencies.clearOutgoingEdgesFor(identity) + nativeFunction, actual_output_type = functionConverter.convertToNativeFunction() if nativeFunction is not None: @@ -900,7 +921,7 @@ def convert( if isRoot: try: self._resolveAllInflightFunctions() - self._installInflightFunctions() + self._installInflightFunctions(identity) return self.getTarget(name) finally: self._inflight_function_conversions.clear() @@ -915,7 +936,7 @@ def convert( raise RuntimeError(f"Unexpected conversion error for {name}") return None - def _installInflightFunctions(self): + def _installInflightFunctions(self, rootIdentity): """Add all function definitions corresponding to keys in inflight_function_conversions to the relevant dictionaries.""" if VALIDATE_FUNCTION_DEFINITIONS_STABLE: # this should always be true, but its expensive so we have it off by default @@ -929,7 +950,17 @@ def _installInflightFunctions(self): finally: self._currentlyConverting = None + # restrict to the set of inflight functions that are reachable from rootName + # we produce copies of functions that we don't actually need to compile during + # early phases of type inference + reachable = self._dependencies.reachableInSet( + rootIdentity, + set(self._inflight_function_conversions) + ) + for identifier, functionConverter in self._inflight_function_conversions.items(): + if identifier not in reachable: + continue outboundTargets = [] for outboundFuncId in self._dependencies.getNamesDependedOn(identifier): name = self._link_name_for_identity[outboundFuncId] @@ -987,3 +1018,6 @@ def _installInflightFunctions(self): self._definitions[name] = nativeFunction self._new_native_functions.add(name) + + self._inflight_definitions.clear() + self._inflight_function_conversions.clear()