diff --git a/README.md b/README.md index 0e7700e..18d141c 100755 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Check out the [PyReason Hello World](https://pyreason.readthedocs.io/en/latest/t ## 1. Introduction PyReason is a graphical inference tool that uses a set of logical rules and facts (initial conditions) to reason over graph structures. To get more details, refer to the paper/video/hello-world-example mentioned above. - + ## 2. Documentation All API documentation and code examples can be found on [ReadTheDocs](https://pyreason.readthedocs.io/en/latest/)
diff --git a/docs/group-chat-example.md b/docs/group-chat-example.md index c2aaa0c..2e29e57 100755 --- a/docs/group-chat-example.md +++ b/docs/group-chat-example.md @@ -3,7 +3,7 @@ Here is an example that utilizes custom thresholds. The following graph represents a network of People and a Text Message in their group chat. - + In this case, we want to know when a text message has been viewed by all members of the group chat. @@ -14,7 +14,7 @@ First, let's create the group chat. import networkx as nx # Create an empty graph
-G = nx.Graph()
+G = nx.DiGraph()
 # Add nodes nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"] @@ -35,7 +35,7 @@ G.add_edges_from(edges) Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows: ```text
-ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
+ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)
``` The `head` of the rule is `ViewedByAll(y)` and the body is `HaveAccess(x,y), Viewed(x)`. 
The head and body are separated by an arrow which means the rule will start evaluating from @@ -79,10 +79,10 @@ We add the facts in PyReason as below: ```python import pyreason as pr -pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True)) -pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True)) -pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True)) -pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True)) +pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3)) +pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3)) +pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3)) +pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3)) ``` This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds diff --git a/docs/hello-world.md b/docs/hello-world.md index f2bf131..cc4a75c 100755 --- a/docs/hello-world.md +++ b/docs/hello-world.md @@ -88,7 +88,7 @@ We add a fact in PyReason like so: ```python import pyreason as pr -pr.add_fact(pr.Fact(name='popular-fact', component='Mary', attribute='popular', bound=[1, 1], start_time=0, end_time=2)) +pr.add_fact(pr.Fact(fact_text='popular(Mary) : true', name='popular_fact', start_time=0, end_time=2)) ``` This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds diff --git a/docs/source/_static/pyreason_logo.jpg b/docs/source/_static/pyreason_logo.jpg new file mode 100755 index 0000000..233618a Binary files /dev/null and b/docs/source/_static/pyreason_logo.jpg differ diff --git a/docs/source/about.rst b/docs/source/about.rst new file mode 100644 index 0000000..ce1614b --- /dev/null +++ b/docs/source/about.rst @@ -0,0 +1,36 @@ +About PyReason +============== + +**PyReason** is a modern Python-based software framework designed for open-world temporal logic 
reasoning using generalized annotated logic. It addresses the growing needs of neuro-symbolic reasoning frameworks that incorporate differentiable logics and temporal extensions, allowing inference over finite periods with open-world capabilities. PyReason is particularly suited for reasoning over graphical structures such as knowledge graphs, social networks, and biological networks, offering fully explainable inference processes.
+
+Key Capabilities
+----------------
+
+1. **Graph-Based Reasoning**: PyReason supports direct reasoning over knowledge graphs, a popular representation of symbolic data. Unlike black-box frameworks, PyReason provides full explainability of the reasoning process.
+
+2. **Annotated Logic**: It extends classical logic with annotations, supporting various types of logic including fuzzy logic, real-valued intervals, and temporal logic. PyReason's framework goes beyond traditional logic systems like Prolog, allowing for arbitrary functions over reals, enhancing its capability to handle constructs in neuro-symbolic reasoning.
+
+3. **Temporal Reasoning**: PyReason includes temporal extensions to handle reasoning over sequences of time points. This feature enables the creation of rules that incorporate temporal dependencies, such as "if condition A, then condition B after a certain number of time steps."
+
+4. **Open World Reasoning**: Unlike closed-world assumptions where anything not explicitly stated is false, PyReason considers unknowns as a valid state, making it more flexible and suitable for real-world applications where information may be incomplete.
+
+5. **Handling Logical Inconsistencies**: PyReason can detect and resolve inconsistencies in the reasoning process. When inconsistencies are found, it can reset affected interpretations to a state of complete uncertainty, ensuring that the reasoning process remains robust.
+
+6. 
**Scalability and Performance**: PyReason is optimized for scalability, supporting exact deductive inference with memory-efficient implementations. It leverages sparsity in graphical structures and employs predicate-constant type checking to reduce computational complexity.
+
+7. **Explainability**: All inference results produced by PyReason are fully explainable, as the software maintains a trace of the inference steps that led to each conclusion. This feature is critical for applications where transparency of the reasoning process is necessary.
+
+8. **Integration and Extensibility**: PyReason is implemented in Python and supports integration with other tools and frameworks, making it easy to extend and adapt for specific needs. It can work with popular graph formats like GraphML and is compatible with tools like NetworkX and Neo4j.
+
+Use Cases
+---------
+
+- **Knowledge Graph Reasoning**: PyReason can be used to perform logical inferences over knowledge graphs, aiding in tasks like knowledge completion, entity classification, and relationship extraction.
+
+- **Temporal Logic Applications**: Its temporal reasoning capabilities are useful in domains requiring time-based analysis, such as monitoring system states over time, or reasoning about events and their sequences.
+
+- **Social and Biological Network Analysis**: PyReason's support for annotated logic and reasoning over complex network structures makes it suitable for applications in social network analysis, supply chain management, and biological systems modeling.
+
+PyReason is open-source and available at `GitHub - PyReason <https://github.com/lab-v2/pyreason>`_.
+
+For more detailed information on PyReason's logical framework, implementation details, and experimental results, refer to the full documentation or visit the project's GitHub repository. 
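The annotated logic described above assigns each atom a closed interval of truth values. As a rough illustration only (the `Interval` class and method names here are hypothetical, not PyReason's API): truth values are intervals `[l, u]` within `[0, 1]`, where `[0, 1]` itself means "unknown" under the open-world assumption, evidence narrows the interval by intersection, and an empty intersection signals an inconsistency.

```python
# Illustrative sketch only -- NOT PyReason's actual classes or API.
# Annotated logic assigns each atom a closed interval [l, u] in [0, 1]:
# [0, 1] is "unknown" (open world), [1, 1] is "true", and narrowing the
# interval by intersection models accumulating evidence.
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    lower: float
    upper: float

    def is_contradiction(self) -> bool:
        # An empty intersection (l > u) signals a logical inconsistency;
        # a system like PyReason would reset such an atom to full uncertainty
        return self.lower > self.upper

    def intersect(self, other: "Interval") -> "Interval":
        return Interval(max(self.lower, other.lower),
                        min(self.upper, other.upper))


UNKNOWN = Interval(0.0, 1.0)
TRUE = Interval(1.0, 1.0)

belief = UNKNOWN.intersect(Interval(0.7, 1.0))   # evidence narrows the bound
conflict = belief.intersect(Interval(0.0, 0.3))  # contradictory evidence
print(belief, conflict.is_contradiction())
```

This also motivates the "reset to complete uncertainty" behavior in capability 5: resetting an inconsistent atom to `[0, 1]` restores a valid (maximally uncertain) state.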
diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst new file mode 100644 index 0000000..49ae7af --- /dev/null +++ b/docs/source/api_reference/index.rst @@ -0,0 +1,10 @@
+API Reference
+=============
+
+In this section we outline the API Reference for the `pyreason` library.
+
+Contents
+--------
+.. toctree::
+   :maxdepth: 2
+   :caption: Contents:
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst index 4cd32c3..4e03901 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,15 +3,32 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive.
-Welcome to Pyreason's documentation!
+Welcome to PyReason Docs!
 ====================================
+.. image:: _static/pyreason_logo.jpg
+   :alt: PyReason Logo
+   :align: center
+
+Introduction
+------------
+Welcome to the documentation for **PyReason**, a powerful, optimized Python tool for reasoning over graphs. PyReason supports a variety of logics, including propositional, first-order, and annotated logic. This documentation will guide you through installation, usage, and the API.
+
 .. toctree::
-   :caption: Tutorials
-   :maxdepth: 2
-   :glob:
+   :maxdepth: 1
+   :caption: Contents:
+
+   about
+   installation
+   user_guide/index
+   api_reference/index
+   tutorials/index
+   license
+
-   ./tutorials/*
+Getting Help
+------------
+If you encounter any issues or have questions, feel free to check our GitHub, or contact one of the authors (`dyuman.aditya@asu.edu`, `kmukher2@asu.edu`).
 Indices and tables ==================
diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 0000000..aca55d1 --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,5 @@
+Installation
+============
+
+TODO: Add installation instructions here. 
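The new installation page is still a TODO. PyReason is published on PyPI (as the project README indicates), so a minimal placeholder for that page would be the standard pip command; exact version and platform requirements should be confirmed against the README before filling in the stub.

```shell
# Install the released package from PyPI; a working Python 3 environment is assumed
pip install pyreason
```
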
+ diff --git a/docs/source/license.rst b/docs/source/license.rst new file mode 100644 index 0000000..9e0cb0f --- /dev/null +++ b/docs/source/license.rst @@ -0,0 +1,4 @@ +License +========== + +TODO: Add license information here. \ No newline at end of file diff --git a/docs/source/tutorials/Advanced tutorial.rst b/docs/source/tutorials/advanced_tutorial.rst similarity index 100% rename from docs/source/tutorials/Advanced tutorial.rst rename to docs/source/tutorials/advanced_tutorial.rst diff --git a/docs/source/tutorials/Basic tutorial.rst b/docs/source/tutorials/basic_tutorial.rst similarity index 100% rename from docs/source/tutorials/Basic tutorial.rst rename to docs/source/tutorials/basic_tutorial.rst diff --git a/docs/source/tutorials/Creating Rules.rst b/docs/source/tutorials/creating_rules.rst similarity index 100% rename from docs/source/tutorials/Creating Rules.rst rename to docs/source/tutorials/creating_rules.rst diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst new file mode 100644 index 0000000..0d2c5bf --- /dev/null +++ b/docs/source/tutorials/index.rst @@ -0,0 +1,14 @@ +Tutorials +========== + +In this section we outline a series of tutorials that will help you get started with the basics of using the `pyreason` library. + +Contents +-------- + +.. 
toctree:: + :maxdepth: 2 + :caption: Contents: + :glob: + + * \ No newline at end of file diff --git a/docs/source/tutorials/Installation.rst b/docs/source/tutorials/installation.rst similarity index 100% rename from docs/source/tutorials/Installation.rst rename to docs/source/tutorials/installation.rst diff --git a/docs/source/tutorials/Rule_image.png b/docs/source/tutorials/rule_image.png similarity index 100% rename from docs/source/tutorials/Rule_image.png rename to docs/source/tutorials/rule_image.png diff --git a/docs/source/tutorials/Understanding Logic.rst b/docs/source/tutorials/understanding_logic.rst similarity index 98% rename from docs/source/tutorials/Understanding Logic.rst rename to docs/source/tutorials/understanding_logic.rst index 28cb625..cf8df84 100644 --- a/docs/source/tutorials/Understanding Logic.rst +++ b/docs/source/tutorials/understanding_logic.rst @@ -50,4 +50,4 @@ Inconsistent predicate list The first rule states that the grass is wet if it rained, while the second rule states that the grass is not wet if it rained. The fact f1 states that it rained, which is consistent with the first rule, but inconsistent with the second rule. -.. |rule_image| image:: Rule_image.png +.. |rule_image| image:: rule_image.png diff --git a/docs/source/user_guide/annotation_functions.rst b/docs/source/user_guide/annotation_functions.rst new file mode 100644 index 0000000..e69de29 diff --git a/docs/source/user_guide/inconsistent_predicate_list.rst b/docs/source/user_guide/inconsistent_predicate_list.rst new file mode 100644 index 0000000..e69de29 diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst new file mode 100644 index 0000000..9a4db18 --- /dev/null +++ b/docs/source/user_guide/index.rst @@ -0,0 +1,20 @@ +User Guide +========== + +In this section we demonstrate the functionality of the `pyreason` library and how to use it. + +Contents +-------- +.. 
toctree::
+   :maxdepth: 2
+   :caption: Contents:
+
+   annotation_functions
+   inconsistent_predicate_list
+   pyreason_facts
+   pyreason_graphs
+   pyreason_rules
+   pyreason_settings
+
diff --git a/docs/source/user_guide/pyreason_facts.rst b/docs/source/user_guide/pyreason_facts.rst new file mode 100644 index 0000000..e69de29
diff --git a/docs/source/user_guide/pyreason_graphs.rst b/docs/source/user_guide/pyreason_graphs.rst new file mode 100644 index 0000000..0e5c8ed --- /dev/null +++ b/docs/source/user_guide/pyreason_graphs.rst @@ -0,0 +1,137 @@
+PyReason Graphs
+===============
+
+PyReason supports direct reasoning over knowledge graphs, with full explainability of the reasoning process. Graphs serve as the knowledge base for PyReason, allowing users to create visual representations based on rules, relationships, and connections.
+
+Methods for Creating Graphs
+---------------------------
+In PyReason there are two ways to create graphs: NetworkX and GraphML. NetworkX allows you to add nodes and edges manually, whereas GraphML reads a directed graph in from a file.
+
+
+NetworkX Example
+----------------
+Using NetworkX, you can create a **directed** graph object. Users can add and remove nodes and edges from the graph.
+
+Read more in the `NetworkX documentation <https://networkx.org>`_.
+
+The following graph represents a network of people and the pets that they own.
+
+1. Mary is friends with Justin
+2. Mary is friends with John
+3. Justin is friends with John
+
+And
+
+1. Mary owns a cat
+2. Justin owns a cat and a dog
+3. John owns a dog
+
+.. code:: python
+
+   import networkx as nx
+
+   # ================================ CREATE GRAPH ====================================
+   # Create a directed graph
+   g = nx.DiGraph()
+
+   # Add the nodes
+   g.add_nodes_from(['John', 'Mary', 'Justin'])
+   g.add_nodes_from(['Dog', 'Cat'])
+
+   # Add the edges and their attributes. When an attribute = x which is <= 1, the annotation
+   # associated with it will be [x,1]. 
NOTE: These attributes are immutable
+   # Friend edges
+   g.add_edge('Justin', 'Mary', Friends=1)
+   g.add_edge('John', 'Mary', Friends=1)
+   g.add_edge('John', 'Justin', Friends=1)
+
+   # Pet edges
+   g.add_edge('Mary', 'Cat', owns=1)
+   g.add_edge('Justin', 'Cat', owns=1)
+   g.add_edge('Justin', 'Dog', owns=1)
+   g.add_edge('John', 'Dog', owns=1)
+
+After the graph has been created, it can be loaded with:
+
+.. code:: python
+
+   import pyreason as pr
+   pr.load_graph(g)
+
+
+Additional Considerations:
+--------------------------
+Attributes to Bounds:
+
+In NetworkX, each graph, node, and edge can hold key/value attribute pairs in an associated attribute dictionary (the keys must be hashable).
+
+In PyReason, these attributes get transformed into "bounds". The attribute value in NetworkX is translated into the lower bound in PyReason.
+
+.. code:: python
+
+   import networkx as nx
+   g = nx.DiGraph()
+   g.add_node("some_node", attribute1=1, attribute2="0,0")
+
+When the graph is loaded, "some_node" is given ``attribute1: [1,1]`` and ``attribute2: [0,0]``.
+
+If the attribute is a simple value, it is treated as both the lower and upper bound in PyReason. If a specific pair of bounds is required (e.g., for coordinates or ranges), the value should be provided as a string in a specific format.
+
+
+GraphML Example
+---------------
+Using `GraphML <http://graphml.graphdrawing.org/>`_, you can read a graph in from a file. For the friends-and-pets graph above, the file looks like this:
+
+.. code:: xml
+
+   <?xml version='1.0' encoding='utf-8'?>
+   <graphml xmlns="http://graphml.graphdrawing.org/xmlns">
+     <key id="Friends" for="edge" attr.name="Friends" attr.type="long" />
+     <key id="owns" for="edge" attr.name="owns" attr.type="long" />
+     <graph edgedefault="directed">
+       <node id="John" />
+       <node id="Mary" />
+       <node id="Justin" />
+       <node id="Dog" />
+       <node id="Cat" />
+       <edge source="Justin" target="Mary"><data key="Friends">1</data></edge>
+       <edge source="John" target="Mary"><data key="Friends">1</data></edge>
+       <edge source="John" target="Justin"><data key="Friends">1</data></edge>
+       <edge source="Mary" target="Cat"><data key="owns">1</data></edge>
+       <edge source="Justin" target="Cat"><data key="owns">1</data></edge>
+       <edge source="Justin" target="Dog"><data key="owns">1</data></edge>
+       <edge source="John" target="Dog"><data key="owns">1</data></edge>
+     </graph>
+   </graphml>
+
+Then load the graph using the following:
+
+.. code:: python
+
+   import pyreason as pr
+   pr.load_graphml('path_to_file')
+
+Graph Output:
+
+.. 
figure:: basic_graph.png
+   :alt: image
+
diff --git a/docs/source/user_guide/pyreason_rules.rst b/docs/source/user_guide/pyreason_rules.rst new file mode 100644 index 0000000..e69de29
diff --git a/docs/source/user_guide/pyreason_settings.rst b/docs/source/user_guide/pyreason_settings.rst new file mode 100644 index 0000000..e69de29
diff --git a/media/group_chat_graph.png b/media/group_chat_graph.png index dc9afac..9da3f58 100644 Binary files a/media/group_chat_graph.png and b/media/group_chat_graph.png differ
diff --git a/pyreason/__init__.py b/pyreason/__init__.py index 85f8319..15fec58 100755 --- a/pyreason/__init__.py +++ b/pyreason/__init__.py @@ -8,23 +8,33 @@ from pyreason.pyreason import * import yaml
+from importlib.metadata import version, PackageNotFoundError
+
+try:
+    __version__ = version(__name__)
+except PackageNotFoundError:
+    # package is not installed
+    pass
 with open(cache_status_path) as file: cache_status = yaml.safe_load(file) if not cache_status['initialized']: - print('Imported PyReason for the first time. Initializing ... this will take a minute') + print('Imported PyReason for the first time. Initializing caches for faster runtimes ... 
this will take a minute') graph_path = os.path.join(package_path, 'examples', 'hello-world', 'friends_graph.graphml') settings.verbose = False load_graphml(graph_path) add_rule(Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2)) + add_fact(Fact('popular(Mary)', 'popular_fact', 0, 2)) reason(timesteps=2) reset() reset_rules() + print('PyReason initialized!') + print() # Update cache status cache_status['initialized'] = True diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index 48ae672..c1c71e4 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -5,6 +5,7 @@ import sys import pandas as pd import memory_profiler as mp +import warnings from typing import List, Type, Callable, Tuple from pyreason.scripts.utils.output import Output @@ -21,11 +22,31 @@ import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval +from pyreason.scripts.utils.reorder_clauses import reorder_clauses # USER VARIABLES class _Settings: def __init__(self): + self.__verbose = None + self.__output_to_file = None + self.__output_file_name = None + self.__graph_attribute_parsing = None + self.__abort_on_inconsistency = None + self.__memory_profile = None + self.__reverse_digraph = None + self.__atom_trace = None + self.__save_graph_attributes_to_trace = None + self.__canonical = None + self.__inconsistency_check = None + self.__static_graph_facts = None + self.__store_interpretation_changes = None + self.__parallel_computing = None + self.__update_mode = None + self.__allow_ground_rules = None + self.reset() + + def reset(self): self.__verbose = True self.__output_to_file = False self.__output_file_name = 'pyreason_output' @@ -41,6 +62,7 @@ def __init__(self): self.__store_interpretation_changes = True 
self.__parallel_computing = False self.__update_mode = 'intersection' + self.__allow_ground_rules = False @property def verbose(self) -> bool: @@ -167,6 +189,14 @@ def update_mode(self) -> str: """ return self.__update_mode + @property + def allow_ground_rules(self) -> bool: + """Returns whether rules can have ground atoms or not. Default is False + + :return: bool + """ + return self.__allow_ground_rules + @verbose.setter def verbose(self, value: bool) -> None: """Set verbose mode. Default is True @@ -354,10 +384,23 @@ def update_mode(self, value: str) -> None: else: self.__update_mode = value + @allow_ground_rules.setter + def allow_ground_rules(self, value: bool) -> None: + """Allow ground atoms to be used in rules when possible. Default is False + + :param value: Whether to allow ground atoms or not + :raises TypeError: If not bool raise error + """ + if not isinstance(value, bool): + raise TypeError('value has to be a bool') + else: + self.__allow_ground_rules = value + # VARIABLES __graph = None __rules = None +__clause_maps = None __node_facts = None __edge_facts = None __ipl = None @@ -399,6 +442,22 @@ def reset_rules(): __rules = None +def reset_graph(): + """ + Resets graph to none + """ + global __graph + __graph = None + + +def reset_settings(): + """ + Resets settings to default + """ + global settings + settings.reset() + + # FUNCTIONS def load_graphml(path: str) -> None: """Loads graph from GraphMl file path into program @@ -451,6 +510,18 @@ def load_inconsistent_predicate_list(path: str) -> None: __ipl = yaml_parser.parse_ipl(path) +def add_inconsistent_predicate(pred1: str, pred2: str) -> None: + """Add an inconsistent predicate pair to the IPL + + :param pred1: First predicate in the inconsistent pair + :param pred2: Second predicate in the inconsistent pair + """ + global __ipl + if __ipl is None: + __ipl = numba.typed.List.empty_list(numba.types.Tuple((label.label_type, label.label_type))) + __ipl.append((label.Label(pred1), label.Label(pred2))) 
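The `_Settings` refactor above moves every default into a `reset()` method that both `__init__` and the new module-level `reset_settings()` reuse, with type-validating property setters. A stripped-down sketch of the same pattern (a hypothetical mini-version with only two options; the real class has many more):

```python
# Sketch of the reset-able settings pattern shown in the diff above.
# Hypothetical mini-version: only two options, not the full _Settings class.
class Settings:
    def __init__(self):
        # Declare private attributes, then delegate all defaults to reset()
        self.__verbose = None
        self.__allow_ground_rules = None
        self.reset()

    def reset(self):
        # Single source of truth for defaults, reused by reset_settings()
        self.__verbose = True
        self.__allow_ground_rules = False

    @property
    def verbose(self) -> bool:
        return self.__verbose

    @verbose.setter
    def verbose(self, value: bool) -> None:
        # Setters validate types, mirroring the style used in the diff
        if not isinstance(value, bool):
            raise TypeError('value has to be a bool')
        self.__verbose = value

    @property
    def allow_ground_rules(self) -> bool:
        return self.__allow_ground_rules


settings = Settings()
settings.verbose = False
settings.reset()          # restores all defaults in one call
print(settings.verbose, settings.allow_ground_rules)
```

Initializing the private attributes to `None` before calling `reset()` keeps all default values in one place, so `reset_settings()` cannot drift out of sync with `__init__`.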
+ + def add_rule(pr_rule: Rule) -> None: """Add a rule to pyreason from text format. This format is not as modular as the YAML format. """ @@ -459,6 +530,11 @@ def add_rule(pr_rule: Rule) -> None: # Add to collection of rules if __rules is None: __rules = numba.typed.List.empty_list(rule.rule_type) + + # Generate name for rule if not set + if pr_rule.rule.get_rule_name() is None: + pr_rule.rule.set_rule_name(f'rule_{len(__rules)}') + __rules.append(pr_rule.rule) @@ -487,16 +563,20 @@ def add_fact(pyreason_fact: Fact) -> None: """ global __node_facts, __edge_facts + if __node_facts is None: + __node_facts = numba.typed.List.empty_list(fact_node.fact_type) + if __edge_facts is None: + __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type) + if pyreason_fact.type == 'node': - f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.label, pyreason_fact.interval, pyreason_fact.t_lower, pyreason_fact.t_upper, pyreason_fact.static) - if __node_facts is None: - __node_facts = numba.typed.List.empty_list(fact_node.fact_type) + if pyreason_fact.name is None: + pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' + f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) __node_facts.append(f) - else: - f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.label, pyreason_fact.interval, pyreason_fact.t_lower, pyreason_fact.t_upper, pyreason_fact.static) - if __edge_facts is None: - __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type) + if pyreason_fact.name is None: + pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' + f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) __edge_facts.append(f) @@ -554,7 +634,7 @@ def reason(timesteps: 
int=-1, convergence_threshold: int=-1, convergence_bound_t def _reason(timesteps, convergence_threshold, convergence_bound_threshold): # Globals - global __graph, __rules, __node_facts, __edge_facts, __ipl, __node_labels, __edge_labels, __specific_node_labels, __specific_edge_labels, __graphml_parser + global __graph, __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __node_labels, __edge_labels, __specific_node_labels, __specific_edge_labels, __graphml_parser global settings, __timestamp, __program # Assert variables are of correct type @@ -564,7 +644,9 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold): # Check variables that HAVE to be set. Exceptions if __graph is None: - raise Exception('Graph not loaded. Use `load_graph` to load the graphml file') + load_graph(nx.DiGraph()) + if settings.verbose: + warnings.warn('Graph not loaded. Use `load_graph` to load the graphml file. Using empty graph') if __rules is None: raise Exception('There are no rules, use `add_rule` or `add_rules_from_file`') @@ -610,8 +692,19 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold): # Convert list of annotation functions into tuple to be numba compatible annotation_functions = tuple(__annotation_functions) + # Optimize rules by moving clauses around, only if there are more edges than nodes in the graph + __clause_maps = {r.get_rule_name(): {i: i for i in range(len(r.get_clauses()))} for r in __rules} + if len(__graph.edges) > len(__graph.nodes): + if settings.verbose: + print('Optimizing rules by moving node clauses ahead of edge clauses') + __rules_copy = __rules.copy() + __rules = numba.typed.List.empty_list(rule.rule_type) + for i, r in enumerate(__rules_copy): + r, __clause_maps[r.get_rule_name()] = reorder_clauses(r) + __rules.append(r) + # Setup logical program - __program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, 
settings.save_graph_attributes_to_trace, settings.canonical, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode) + __program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.canonical, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules) __program.available_labels_node = __node_labels __program.available_labels_edge = __edge_labels __program.specific_node_labels = __specific_node_labels @@ -651,11 +744,11 @@ def save_rule_trace(interpretation, folder: str='./'): :param interpretation: the output of `pyreason.reason()`, the final interpretation :param folder: the folder in which to save the result, defaults to './' """ - global __timestamp, settings + global __timestamp, __clause_maps, settings assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace' - output = Output(__timestamp) + output = Output(__timestamp, __clause_maps) output.save_rule_trace(interpretation, folder) @@ -667,11 +760,11 @@ def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]: :param interpretation: the output of `pyreason.reason()`, the final interpretation :returns two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning """ - global __timestamp, settings + global __timestamp, __clause_maps, settings assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace' - output = Output(__timestamp) + output = Output(__timestamp, __clause_maps) return output.get_rule_trace(interpretation) diff --git a/pyreason/scripts/facts/fact.py b/pyreason/scripts/facts/fact.py index 1004cec..44d823d 100644 --- a/pyreason/scripts/facts/fact.py +++ 
b/pyreason/scripts/facts/fact.py @@ -1,23 +1,15 @@ -import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval +import pyreason.scripts.utils.fact_parser as fact_parser import pyreason.scripts.numba_wrapper.numba_types.label_type as label -from typing import Tuple -from typing import List -from typing import Union - class Fact: - def __init__(self, name: str, component: Union[str, Tuple[str, str]], attribute: str, bound: Union[interval.Interval, List[float]], start_time: int, end_time: int, static: bool = False): + def __init__(self, fact_text: str, name: str = None, start_time: int = 0, end_time: int = 0, static: bool = False): """Define a PyReason fact that can be loaded into the program using `pr.add_fact()` + :param fact_text: The fact in text format. Example: `'pred(x,y) : [0.2, 1]'` or `'pred(x,y) : True'` + :type fact_text: str :param name: The name of the fact. This will appear in the trace so that you know when it was applied :type name: str - :param component: The node or edge that whose attribute you want to change - :type component: str | Tuple[str, str] - :param attribute: The attribute you would like to change for the specified node/edge - :type attribute: str - :param bound: The bound to which you'd like to set the attribute corresponding to the specified node/edge - :type bound: interval.Interval | List[float] :param start_time: The timestep at which this fact becomes active :type start_time: int :param end_time: The last timestep this fact is active @@ -25,23 +17,12 @@ def __init__(self, name: str, component: Union[str, Tuple[str, str]], attribute: :param static: If the fact should be active for the entire program. 
In which case `start_time` and `end_time` will be ignored :type static: bool """
+        pred, component, bound, fact_type = fact_parser.parse_fact(fact_text)
         self.name = name
-        self.t_upper = end_time
-        self.t_lower = start_time
-        self.component = component
-        self.label = attribute
-        self.interval = bound
+        self.start_time = start_time
+        self.end_time = end_time
         self.static = static
-
-        # Check if it is a node fact or edge fact
-        if isinstance(self.component, str):
-            self.type = 'node'
-        else:
-            self.type = 'edge'
-
-        # Set label to correct type
-        self.label = label.Label(attribute)
-
-        # Set bound to correct type
-        if isinstance(bound, list):
-            self.interval = interval.closed(*bound)
+        self.pred = label.Label(pred)
+        self.component = component
+        self.bound = bound
+        self.type = fact_type
diff --git a/pyreason/scripts/facts/fact_edge.py b/pyreason/scripts/facts/fact_edge.py index 935d40f..bbeb3e6 100755 --- a/pyreason/scripts/facts/fact_edge.py +++ b/pyreason/scripts/facts/fact_edge.py @@ -12,6 +12,9 @@ def __init__(self, name, component, label, interval, t_lower, t_upper, static=Fa def get_name(self): return self._name
+    def set_name(self, name):
+        self._name = name
+
     def get_component(self): return self._component
diff --git a/pyreason/scripts/facts/fact_node.py b/pyreason/scripts/facts/fact_node.py index 92e97c8..69e379e 100755 --- a/pyreason/scripts/facts/fact_node.py +++ b/pyreason/scripts/facts/fact_node.py @@ -12,6 +12,9 @@ def __init__(self, name, component, label, interval, t_lower, t_upper, static=Fa def get_name(self): return self._name
+    def set_name(self, name):
+        self._name = name
+
     def get_component(self): return self._component
diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py index 779ce68..858620e 100755 --- a/pyreason/scripts/interpretation/interpretation.py +++ b/pyreason/scripts/interpretation/interpretation.py @@ -1,3 +1,5 @@
 import 
pyreason.scripts.numba_wrapper.numba_types.world_type as world import pyreason.scripts.numba_wrapper.numba_types.label_type as label import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval @@ -15,6 +17,12 @@ list_of_nodes = numba.types.ListType(node_type) list_of_edges = numba.types.ListType(edge_type) +# Type for storing clause data +clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string))) + +# Type for storing refine clause data +refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8)) + # Type for facts to be applied facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) @@ -37,6 +45,11 @@ numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) )) +rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) +rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) +rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)) +edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) + class Interpretation: available_labels_node = [] @@ -44,7 +57,7 @@ class Interpretation: specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type)) specific_edge_labels = 
numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type)) - def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode): + def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules): self.graph = graph self.ipl = ipl self.annotation_functions = annotation_functions @@ -55,18 +68,19 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, self.inconsistency_check = inconsistency_check self.store_interpretation_changes = store_interpretation_changes self.update_mode = update_mode + self.allow_ground_rules = allow_ground_rules # For reasoning and reasoning again (contains previous time and previous fp operation cnt) self.time = 0 self.prev_reasoning_data = numba.typed.List([0, 0]) # Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. 
One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true - self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))) - self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))) + self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) + self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string) self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string) - self.rules_to_be_applied_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))) - self.rules_to_be_applied_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))) + self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type) + self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type) self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type) self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type) self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) @@ -94,8 +108,8 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, else: self.available_labels_edge = 
numba.typed.List(self.available_labels_edge) - self.interpretations_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels) - self.interpretations_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels) + self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels) + self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels) # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) @@ -126,6 +140,7 @@ def _init_reverse_neighbors(neighbors): @numba.njit(cache=True) def _init_interpretations_node(nodes, available_labels, specific_labels): interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes) # General labels for n in nodes: interpretations[n] = world.World(available_labels) @@ -134,12 +149,19 @@ def _init_interpretations_node(nodes, available_labels, specific_labels): for n in ns: interpretations[n].world[l] = interval.closed(0.0, 1.0) - return interpretations - + for l in available_labels: + predicate_map[l] = numba.typed.List(nodes) + + for l, ns in specific_labels.items(): + predicate_map[l] = numba.typed.List(ns) + + return interpretations, predicate_map + @staticmethod @numba.njit(cache=True) def _init_interpretations_edge(edges, available_labels, specific_labels): interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges) # General labels for e in edges: interpretations[e] = world.World(available_labels) @@ -148,8 +170,14 @@ def 
_init_interpretations_edge(edges, available_labels, specific_labels): for e in es: interpretations[e].world[l] = interval.closed(0.0, 1.0) - return interpretations - + for l in available_labels: + predicate_map[l] = numba.typed.List(edges) + + for l, es in specific_labels.items(): + predicate_map[l] = numba.typed.List(es) + + return interpretations, predicate_map + @staticmethod @numba.njit(cache=True) def _init_convergence(convergence_bound_threshold, convergence_threshold): @@ -193,7 +221,7 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap return max_time def _start_fp(self, rules, max_facts_time, verbose, again): - fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again) + fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, 
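The `_init_interpretations_*` functions now return a `predicate_map` alongside the interpretations: an inverted index from each label to the components that carry it, so rule grounding can look up candidates per predicate instead of scanning every node or edge. Stripped of the numba typed containers, the idea reduces to a plain dictionary inversion (illustrative sketch only):

```python
def build_predicate_map(interpretations):
    """Invert a component -> labels mapping into label -> components,
    mirroring what predicate_map_node / predicate_map_edge hold."""
    predicate_map = {}
    for component, labels in interpretations.items():
        for l in labels:
            predicate_map.setdefault(l, []).append(component)
    return predicate_map
```

With `{'Mary': ['popular'], 'Justin': ['popular', 'owns']}` the map sends `'popular'` to both nodes and `'owns'` to `'Justin'` alone, which is exactly the candidate set a clause over that predicate needs.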
self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again) self.time = t - 1 # If we need to reason again, store the next timestep to start from self.prev_reasoning_data[0] = t @@ -202,15 +230,16 @@ def _start_fp(self, rules, max_facts_time, verbose, again): print('Fixed Point iterations:', fp_cnt) @staticmethod - @numba.njit(cache=True) - def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again): + @numba.njit(cache=True, parallel=False) + def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, 
rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again): t = prev_reasoning_data[0] fp_cnt = prev_reasoning_data[1] max_rules_time = 0 timestep_loop = True facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type) facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type) - rules_to_remove_idx = numba.typed.List.empty_list(numba.types.int64) + rules_to_remove_idx = set() + rules_to_remove_idx.add(-1) while timestep_loop: if t==tmax: timestep_loop = False @@ -238,24 +267,18 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bound_delta = 0 update = False - # Parameters for immediate rules - immediate_node_rule_fire = False - immediate_edge_rule_fire = False - immediate_rule_applied = False - # When delta_t = 0, we don't want to check the same rule with the same node/edge after coming back to the fp operator - nodes_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_nodes) - edges_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_edges) - # Initialize the above - for i in range(len(rules)): - nodes_to_skip[i] = numba.typed.List.empty_list(node_type) - edges_to_skip[i] = numba.typed.List.empty_list(edge_type) - # Start by applying facts # Nodes facts_to_be_applied_node_new.clear() + nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): - if facts_to_be_applied_node[i][0]==t: + if facts_to_be_applied_node[i][0] == t: comp, l, bnd, static, graph_attribute = 
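`rules_to_remove_idx` changes from a typed list to a set seeded with `-1`: numba's typed containers need at least one element to infer their type, and `-1` can never collide with a real list index, so the sentinel is harmless. The removal pattern that consumes it is a plain index filter (Python sketch of the list-comprehension used in the diff):

```python
def drop_applied(items, applied_idx):
    """Keep only the entries whose index was not marked as applied.
    applied_idx is a set seeded with -1, which matches no real index."""
    return [item for i, item in enumerate(items) if i not in applied_idx]
```

Set membership makes each `i not in applied_idx` check O(1), versus O(k) for the previous list.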
facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + # If the component is not in the graph, add it + if comp not in nodes_set: + _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) + nodes_set.add(comp) + # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): # Check if we should even store any of the changes to the rule trace etc. @@ -273,13 +296,13 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) - + else: # Check for inconsistencies (multiple facts) if check_consistent_node(interpretations_node, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, 
store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -289,11 +312,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Resolve inconsistency if necessary otherwise override bounds else: + mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes) + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -315,9 +338,15 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Edges facts_to_be_applied_edge_new.clear() + edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], 
facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + # If the component is not in the graph, add it + if comp not in edges_set: + _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge) + edges_set.add(comp) + # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute @@ -339,7 +368,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data if check_consistent_edge(interpretations_edge, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -349,11 +378,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Resolve inconsistency else: + mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, 
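Both fact-application loops now add a missing component to the graph on the fly (via `_add_node` / `_add_edge`) instead of silently skipping the fact, with `nodes_set` / `edges_set` kept alongside the typed lists for O(1) membership tests. The node-side guard amounts to the following (simplified, non-numba sketch; the real `_add_node` takes the interpretation dict as well):

```python
def ensure_node(comp, nodes, nodes_set, neighbors, reverse_neighbors):
    """If a fact references a component absent from the graph, register it
    before its interpretation is updated."""
    if comp not in nodes_set:
        nodes.append(comp)
        nodes_set.add(comp)          # mirror of `nodes` for fast lookups
        neighbors[comp] = []
        reverse_neighbors[comp] = []
```

A repeated call with an existing component is a no-op, so the guard is safe to run once per fact per timestep.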
rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) else: - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -382,50 +411,25 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Nodes rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): - # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied - if immediate_node_rule_fire and rules_to_be_applied_node[-1][4]: - i = rules_to_be_applied_node[-1] - idx = len(rules_to_be_applied_node) - 1 - - if i[0]==t: + if i[0] == t: comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5] - sources, targets, edge_l = edges_to_be_added_node_rule[idx] - edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge) - changes_cnt += changes - - # Update bound for newly added edges. 
Use bnd to update all edges if label is specified, else use bnd to update normally - if edge_l.value!='': - for e in edges_added: - if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) - - update = u or update - - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) - else: - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) - - update = u or update + # Check for inconsistencies + if check_consistent_node(interpretations_node, comp, (l, bnd)): + override = True if update_mode == 'override' else False + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes 
+ update = u or update + # Update convergence params + if convergence_mode=='delta_bound': + bound_delta = max(bound_delta, changes) + else: + changes_cnt += changes + # Resolve inconsistency else: - # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) + if inconsistency_check: + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + else: + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -433,32 +437,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bound_delta = max(bound_delta, changes) else: changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes) - else: - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, 
facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes # Delete rules that have been applied from list by adding index to list - rules_to_remove_idx.append(idx) - - # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back - if immediate: - # If delta_t=0 we want to apply one rule and go back to the fp operator - # If delta_t>0 we want to come back here and apply the rest of the rules - if immediate_edge_rule_fire: - break - elif not immediate_edge_rule_fire and u: - immediate_rule_applied = True - break + rules_to_remove_idx.add(idx) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx]) @@ -469,26 +450,20 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Edges rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): - # If we broke from above loop to apply more rules, then break from here - if immediate_rule_applied and not immediate_edge_rule_fire: - break - # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. 
Which was just added to the list to_be_applied - if immediate_edge_rule_fire and rules_to_be_applied_edge[-1][4]: - i = rules_to_be_applied_edge[-1] - idx = len(rules_to_be_applied_edge) - 1 - - if i[0]==t: + if i[0] == t: comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] - edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge) + edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge) changes_cnt += changes # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally - if edge_l.value!='': + if edge_l.value != '': for e in edges_added: + if interpretations_edge[e].world[edge_l].is_static(): + continue if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update @@ -500,9 +475,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, e, (edge_l, 
bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update @@ -516,7 +491,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Check for inconsistencies if check_consistent_edge(interpretations_edge, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update 
= u or update # Update convergence params @@ -527,9 +502,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -539,17 +514,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Delete rules that have been applied from list by adding the index to list - rules_to_remove_idx.append(idx) - - # Break out of the apply rules loop if a rule is immediate. 
Then we go to the fp operator and check for other applicable rules then come back - if immediate: - # If t=0 we want to apply one rule and go back to the fp operator - # If t>0 we want to come back here and apply the rest of the rules - if immediate_edge_rule_fire: - break - elif not immediate_edge_rule_fire and u: - immediate_rule_applied = True - break + rules_to_remove_idx.add(idx) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx]) @@ -560,59 +525,45 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Fixed point # if update or immediate_node_rule_fire or immediate_edge_rule_fire or immediate_rule_applied: if update: - # Increase fp operator count only if not an immediate rule - if not (immediate_node_rule_fire or immediate_edge_rule_fire): - fp_cnt += 1 + # Increase fp operator count + fp_cnt += 1 - for i in range(len(rules)): + # Lists or threadsafe operations (when parallel is on) + rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))]) + rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))]) + if atom_trace: + rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) + rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) + edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))]) + + for i in prange(len(rules)): rule = rules[i] immediate_rule = rule.is_immediate_rule() - 
immediate_node_rule_fire = False - immediate_edge_rule_fire = False # Only go through if the rule can be applied within the given timesteps, or we're running until convergence delta_t = rule.get_delta() if t + delta_t <= tmax or tmax == -1 or again: - applicable_node_rules = _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip[i]) - applicable_edge_rules = _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip[i]) + applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules) # Loop through applicable rules and add them to the rules to be applied for later or next fp operation for applicable_rule in applicable_node_rules: - n, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule + n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule # If there is an edge to add or the predicate doesn't exist or the interpretation is not static - if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static(): + if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static(): bnd = annotate(annotation_functions, rule, annotations, rule.get_weights()) # Bound annotations in between 0 and 1 bnd_l = min(max(bnd[0], 0), 1) bnd_u = min(max(bnd[1], 0), 1) bnd = interval.closed(bnd_l, bnd_u) max_rules_time = max(max_rules_time, t + delta_t) - edges_to_be_added_node_rule.append(edges_to_add) - rules_to_be_applied_node.append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + 
rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) if atom_trace: - rules_to_be_applied_node_trace.append((qualified_nodes, qualified_edges, rule.get_name())) - - # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance - # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0 - if delta_t != 0: - nodes_to_skip[i].append(n) + rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - # Handle loop parameters for the next (maybe) fp operation - # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire - # Next fp operation we will skip this rule on this node because anyway there won't be an update + # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: in_loop = True update = False - if immediate_rule and delta_t == 0: - # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done. 
- in_loop = True - update = True - immediate_node_rule_fire = True - break - - # Break, apply immediate rule then come back to check for more applicable rules - if immediate_node_rule_fire: - break for applicable_rule in applicable_edge_rules: e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule @@ -624,51 +575,44 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bnd_u = min(max(bnd[1], 0), 1) bnd = interval.closed(bnd_l, bnd_u) max_rules_time = max(max_rules_time, t+delta_t) - edges_to_be_added_edge_rule.append(edges_to_add) - rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + # edges_to_be_added_edge_rule.append(edges_to_add) + edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add) + # rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) if atom_trace: - rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name())) + # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name())) + rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance - # It's possible to have an annotation function that keeps changing the value of a node/edge. 
Do this only for delta_t>0 - if delta_t != 0: - edges_to_skip[i].append(e) - - # Handle loop parameters for the next (maybe) fp operation - # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire - # Next fp operation we will skip this rule on this node because anyway there won't be an update + # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: in_loop = True update = False - if immediate_rule and delta_t == 0: - # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done. - in_loop = True - update = True - immediate_edge_rule_fire = True - break - - # Break, apply immediate rule then come back to check for more applicable rules - if immediate_edge_rule_fire: - break - - # Go through all the rules and go back to applying the rules if we came here because of an immediate rule where delta_t>0 - if immediate_rule_applied and not (immediate_node_rule_fire or immediate_edge_rule_fire): - immediate_rule_applied = False - in_loop = True - update = False - continue - + + # Update lists after parallel run + for i in range(len(rules)): + if len(rules_to_be_applied_node_threadsafe[i]) > 0: + rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) + if len(rules_to_be_applied_edge_threadsafe[i]) > 0: + rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) + if atom_trace: + if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: + rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) + if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: + rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) + if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: + edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) + # Check for convergence after each timestep (perfect convergence or convergence specified by user) # 
Check number of changed interpretations or max bound change # User specified convergence - if convergence_mode=='delta_interpretation': + if convergence_mode == 'delta_interpretation': if changes_cnt <= convergence_delta: if verbose: print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation') # Be consistent with time returned when we don't converge t += 1 break - elif convergence_mode=='delta_bound': + elif convergence_mode == 'delta_bound': if bound_delta <= convergence_delta: if verbose: print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation') @@ -678,8 +622,8 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Perfect convergence # Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable # If no more rules/facts to be applied - elif convergence_mode=='perfect_convergence': - if t>=max_facts_time and t>=max_rules_time: + elif convergence_mode == 'perfect_convergence': + if t >= max_facts_time and t >= max_rules_time: if verbose: print(f'\nConverged at time: {t}') # Be consistent with time returned when we don't converge @@ -693,7 +637,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data def add_edge(self, edge, l): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally @@ -704,19 +648,19 @@ def add_node(self, node, labels): # This function
is useful for pyreason gym, called externally - _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge) + _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge) def delete_node(self, node): # This function is useful for pyreason gym, called externally - _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) + _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node) - def get_interpretation_dict(self): + def get_dict(self): # This function can be called externally to retrieve a dict of the interpretation values # Only values in the rule trace will be added # Initialize interpretations for each time and node and edge interpretations = {} - for t in range(self.tmax+1): + for t in range(self.time+1): interpretations[t] = {} for node in self.nodes: interpretations[t][node] = InterpretationDict() @@ -730,7 +674,7 @@ def get_interpretation_dict(self): # If canonical, update all following timesteps as well if self.canonical: - for t in range(time+1, self.tmax+1): + for t in range(time+1, self.time+1): interpretations[t][node][l._value] = (bnd.lower, bnd.upper) # Update interpretation edges @@ -740,11 +684,524 @@ def get_interpretation_dict(self): # If canonical, update all following timesteps as well if self.canonical:
- for t in range(time+1, self.tmax+1): + for t in range(time+1, self.time+1): interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) return interpretations + def query(self, query, return_bool=True): + """ + This function is used to query the graph after reasoning + :param query: The query string of the form `pred(node)` or `pred(edge)` or `pred(node) : [l, u]` + :param return_bool: If True, returns the boolean result of the query, else the bounds associated with it + :return: bool, or bounds + """ + # Parse the query + query = query.replace(' ', '') + + if ':' in query: + pred_comp, bounds = query.split(':') + bounds = bounds.replace('[', '').replace(']', '') + l, u = bounds.split(',') + l, u = float(l), float(u) + else: + if query[0] == '~': + pred_comp = query[1:] + l, u = 0, 0 + else: + pred_comp = query + l, u = 1, 1 + + bnd = interval.closed(l, u) + + # Split predicate and component + idx = pred_comp.find('(') + pred = label.Label(pred_comp[:idx]) + component = pred_comp[idx + 1:-1] + + if ',' in component: + component = tuple(component.split(',')) + comp_type = 'edge' + else: + comp_type = 'node' + + # Check if the component exists + if comp_type == 'node': + if component not in self.nodes: + return False if return_bool else (0, 0) + else: + if component not in self.edges: + return False if return_bool else (0, 0) + + # Check if the predicate exists + if comp_type == 'node': + if pred not in self.interpretations_node[component].world: + return False if return_bool else (0, 0) + else: + if pred not in self.interpretations_edge[component].world: + return False if return_bool else (0, 0) + + # Check if the bounds are satisfied + if comp_type == 'node': + if self.interpretations_node[component].world[pred] in bnd: + return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper) + else: + return False if return_bool else (0, 0) + else: + if
self.interpretations_edge[component].world[pred] in bnd: + return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper) + else: + return False if return_bool else (0, 0) + + +@numba.njit(cache=True) +def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules): + # Extract rule params + rule_type = rule.get_type() + head_variables = rule.get_head_variables() + clauses = rule.get_clauses() + thresholds = rule.get_thresholds() + ann_fn = rule.get_annotation_function() + rule_edges = rule.get_edges() + + if rule_type == 'node': + head_var_1 = head_variables[0] + else: + head_var_1, head_var_2 = head_variables[0], head_variables[1] + + # We return a list of tuples which specify the target nodes/edges that have made the rule body true + applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type) + applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type) + + # Grounding procedure + # 1. Go through each clause and check which variables have not been initialized in groundings + # 2. 
Check satisfaction of variables based on the predicate in the clause + + # Grounding variable that maps variables in the body to a list of grounded nodes + # Grounding edges that maps edge variables to a list of edges + groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes) + groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges) + + # Dependency graph that keeps track of the connections between the variables in the body + dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) + dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) + + nodes_set = set(nodes) + edges_set = set(edges) + + satisfaction = True + for i, clause in enumerate(clauses): + # Unpack clause variables + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + clause_bnd = clause[3] + clause_operator = clause[4] + + # This is a node clause + if clause_type == 'node': + clause_var_1 = clause_variables[0] + + # Get subset of nodes that can be used to ground the variable + # If we allow ground atoms, we can use the nodes directly + if allow_ground_rules and clause_var_1 in nodes_set: + grounding = numba.typed.List([clause_var_1]) + else: + grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes) + + # Narrow subset based on predicate + qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd) + groundings[clause_var_1] = qualified_groundings + qualified_groundings_set = set(qualified_groundings) + for c1, c2 in groundings_edges: + if c1 == clause_var_1: + groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set]) + if c2 == clause_var_1: + groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in 
qualified_groundings_set]) + + # Check satisfaction of those nodes wrt the threshold + # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule + # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges + # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: + satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction + + # This is an edge clause + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] + + # Get subset of edges that can be used to ground the variables + # If we allow ground atoms, we can use the nodes directly + if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set: + grounding = numba.typed.List([(clause_var_1, clause_var_2)]) + else: + grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges) + + # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster) + qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd) + + # Check satisfaction of those edges wrt the threshold + # Only check satisfaction if the default threshold is used. 
This saves us from grounding the rest of the rule + # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges + # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: + satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction + + # Update the groundings + groundings[clause_var_1] = numba.typed.List.empty_list(node_type) + groundings[clause_var_2] = numba.typed.List.empty_list(node_type) + groundings_clause_1_set = set(groundings[clause_var_1]) + groundings_clause_2_set = set(groundings[clause_var_2]) + for e in qualified_groundings: + if e[0] not in groundings_clause_1_set: + groundings[clause_var_1].append(e[0]) + groundings_clause_1_set.add(e[0]) + if e[1] not in groundings_clause_2_set: + groundings[clause_var_2].append(e[1]) + groundings_clause_2_set.add(e[1]) + + # Update the edge groundings (to use later for grounding other clauses with the same variables) + groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings + + # Update dependency graph + # Add a connection between clause_var_1 -> clause_var_2 and vice versa + if clause_var_1 not in dependency_graph_neighbors: + dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2]) + elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]: + dependency_graph_neighbors[clause_var_1].append(clause_var_2) + if clause_var_2 not in dependency_graph_reverse_neighbors: + dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1]) + elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]: + dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1) + + # This is a comparison clause + else: + pass + + # Refine the subsets based on any updates + if satisfaction: + refine_groundings(clause_variables, groundings, groundings_edges, 
dependency_graph_neighbors, dependency_graph_reverse_neighbors) + + # If satisfaction is false, break + if not satisfaction: + break + + # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules + # Then continue to setup any edges to be added and annotations + # Fill out the rules to be applied lists + if satisfaction: + # Create temp grounding containers to verify if the head groundings are valid (only for edge rules) + # Setup edges to be added and fill rules to be applied + # Setup traces and inputs for annotation function + # Loop through the clause data and setup final annotations and trace variables + # Three cases: 1.node rule, 2. edge rule with infer edges, 3. edge rule + if rule_type == 'node': + # Loop through all the head variable groundings and add it to the rules to be applied + # Loop through the clauses and add appropriate trace data and annotations + + # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph + head_var_1_in_nodes = head_var_1 in nodes + add_head_var_node_to_graph = False + if allow_ground_rules and head_var_1_in_nodes: + groundings[head_var_1] = numba.typed.List([head_var_1]) + elif head_var_1 not in groundings: + if not head_var_1_in_nodes: + add_head_var_node_to_graph = True + groundings[head_var_1] = numba.typed.List([head_var_1]) + + for head_grounding in groundings[head_var_1]: + qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) + edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + + # Check for satisfaction one more time in case the refining process has changed the groundings + satisfaction = check_all_clause_satisfaction(interpretations_node, 
interpretations_edge, clauses, thresholds, groundings, groundings_edges) + if not satisfaction: + continue + + for i, clause in enumerate(clauses): + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + + if clause_type == 'node': + clause_var_1 = clause_variables[0] + + # 1. + if atom_trace: + if clause_var_1 == head_var_1: + qualified_nodes.append(numba.typed.List([head_grounding])) + else: + qualified_nodes.append(numba.typed.List(groundings[clause_var_1])) + qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # 2. + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1: + a.append(interpretations_node[head_grounding].world[clause_label]) + else: + for qn in groundings[clause_var_1]: + a.append(interpretations_node[qn].world[clause_label]) + annotations.append(a) + + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] + # 1. + if atom_trace: + # Cases: Both equal, one equal, none equal + qualified_nodes.append(numba.typed.List.empty_list(node_type)) + if clause_var_1 == head_var_1: + es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding]) + qualified_edges.append(es) + elif clause_var_2 == head_var_1: + es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding]) + qualified_edges.append(es) + else: + qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)])) + # 2. 
+ if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1: + for e in groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_2 == head_var_1: + for e in groundings_edges[(clause_var_1, clause_var_2)]: + if e[1] == head_grounding: + a.append(interpretations_edge[e].world[clause_label]) + else: + for qe in groundings_edges[(clause_var_1, clause_var_2)]: + a.append(interpretations_edge[qe].world[clause_label]) + annotations.append(a) + else: + # Comparison clause (we do not handle for now) + pass + + # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules) + if add_head_var_node_to_graph: + _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node) + + # For each grounding add a rule to be applied + applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) + + elif rule_type == 'edge': + head_var_1 = head_variables[0] + head_var_2 = head_variables[1] + + # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph + head_var_1_in_nodes = head_var_1 in nodes + head_var_2_in_nodes = head_var_2 in nodes + add_head_var_1_node_to_graph = False + add_head_var_2_node_to_graph = False + add_head_edge_to_graph = False + if allow_ground_rules and head_var_1_in_nodes: + groundings[head_var_1] = numba.typed.List([head_var_1]) + if allow_ground_rules and head_var_2_in_nodes: + groundings[head_var_2] = numba.typed.List([head_var_2]) + + if head_var_1 not in groundings: + if not head_var_1_in_nodes: + add_head_var_1_node_to_graph = True + groundings[head_var_1] = numba.typed.List([head_var_1]) + if head_var_2 not in groundings: + if not head_var_2_in_nodes: + add_head_var_2_node_to_graph = True + groundings[head_var_2] = numba.typed.List([head_var_2]) + + # 
Artificially connect the head variables with an edge if both of them were not in the graph + if not head_var_1_in_nodes and not head_var_2_in_nodes: + add_head_edge_to_graph = True + + head_var_1_groundings = groundings[head_var_1] + head_var_2_groundings = groundings[head_var_2] + + source, target, _ = rule_edges + infer_edges = True if source != '' and target != '' else False + + # Prepare the edges that we will loop over. + # For infer edges we loop over each combination pair + # Else we loop over the valid edges in the graph + valid_edge_groundings = numba.typed.List.empty_list(edge_type) + for g1 in head_var_1_groundings: + for g2 in head_var_2_groundings: + if infer_edges: + valid_edge_groundings.append((g1, g2)) + else: + if (g1, g2) in edges_set: + valid_edge_groundings.append((g1, g2)) + + # Loop through the head variable groundings + for valid_e in valid_edge_groundings: + head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1] + qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) + edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + + # Containers to keep track of groundings to make sure that the edge pair is valid + # We do this because we cannot know beforehand the edge matches from source groundings to target groundings + temp_groundings = groundings.copy() + temp_groundings_edges = groundings_edges.copy() + + # Refine the temp groundings for the specific edge head grounding + # We update the edge collection as well depending on if there's a match between the clause variables and head variables + temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding]) + temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding]) + for c1, c2 in 
temp_groundings_edges.keys(): + if c1 == head_var_1 and c2 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)]) + elif c1 == head_var_2 and c2 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)]) + elif c1 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding]) + elif c2 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding]) + elif c1 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding]) + elif c2 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding]) + + refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors) + + # Check if the thresholds are still satisfied + # Check if all clauses are satisfied again in case the refining process changed anything + satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges) + + if not satisfaction: + continue + + if infer_edges: + # Prevent self loops while inferring edges if the clause variables are not the same + if source != target and head_var_1_grounding == head_var_2_grounding: + continue + edges_to_be_added[0].append(head_var_1_grounding) + edges_to_be_added[1].append(head_var_2_grounding) + + for i, clause in enumerate(clauses): + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + + if clause_type == 'node': + clause_var_1 = 
clause_variables[0] + # 1. + if atom_trace: + if clause_var_1 == head_var_1: + qualified_nodes.append(numba.typed.List([head_var_1_grounding])) + elif clause_var_1 == head_var_2: + qualified_nodes.append(numba.typed.List([head_var_2_grounding])) + else: + qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1])) + qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # 2. + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1: + a.append(interpretations_node[head_var_1_grounding].world[clause_label]) + elif clause_var_1 == head_var_2: + a.append(interpretations_node[head_var_2_grounding].world[clause_label]) + else: + for qn in temp_groundings[clause_var_1]: + a.append(interpretations_node[qn].world[clause_label]) + annotations.append(a) + + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] + # 1. + if atom_trace: + # Cases: + # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1) + # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2) + # 3. 
None equal + qualified_nodes.append(numba.typed.List.empty_list(node_type)) + if clause_var_1 == head_var_1 and clause_var_2 == head_var_2: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding]) + qualified_edges.append(es) + elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding]) + qualified_edges.append(es) + elif clause_var_1 == head_var_1: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding]) + qualified_edges.append(es) + elif clause_var_1 == head_var_2: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding]) + qualified_edges.append(es) + elif clause_var_2 == head_var_1: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding]) + qualified_edges.append(es) + elif clause_var_2 == head_var_2: + es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding]) + qualified_edges.append(es) + else: + qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)])) + + # 2. 
+ if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1 and clause_var_2 == head_var_2: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_1 == head_var_1: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_var_1_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_1 == head_var_2: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_var_2_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_2 == head_var_1: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[1] == head_var_1_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_2 == head_var_2: + for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: + if e[1] == head_var_2_grounding: + a.append(interpretations_edge[e].world[clause_label]) + else: + for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]: + a.append(interpretations_edge[qe].world[clause_label]) + annotations.append(a) + + # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules) + if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1: + _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node) + if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2: + _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node) + if add_head_edge_to_graph and (head_var_1, head_var_2) == 
(head_var_1_grounding, head_var_2_grounding):
+                        _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge)
+
+                    # For each grounding combination add a rule to be applied
+                    # Only if all the clauses have valid groundings
+                    # if satisfaction:
+                    e = (head_var_1_grounding, head_var_2_grounding)
+                    applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+    # Return the applicable rules
+    return applicable_rules_node, applicable_rules_edge
+
+
+@numba.njit(cache=True)
+def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
+    # Check if the thresholds are satisfied for each clause
+    satisfaction = True
+    for i, clause in enumerate(clauses):
+        # Unpack clause variables
+        clause_type = clause[0]
+        clause_label = clause[1]
+        clause_variables = clause[2]
+
+        if clause_type == 'node':
+            clause_var_1 = clause_variables[0]
+            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
+        elif clause_type == 'edge':
+            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction
+    return satisfaction
+

 @numba.njit(cache=True)
 def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip):
@@ -757,6 +1214,10 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
     # We return a list of tuples which specify the target nodes/edges that have made the rule body true
     applicable_rules = 
numba.typed.List.empty_list(node_applicable_rule_type) + + # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe + # One array for each node, then condense into a single list later + applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(node_applicable_rule_type) for _ in nodes]) # Return empty list if rule is not node rule and if we are not inferring edges if rule_type != 'node' and rule_edges[0] == '': @@ -781,6 +1242,7 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + clause_type_and_variables = numba.typed.List.empty_list(clause_data) satisfaction = True for i, clause in enumerate(clauses): @@ -791,28 +1253,16 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n clause_bnd = clause[3] clause_operator = clause[4] - # Unpack thresholds - # This value is total/available - threshold_quantifier_type = thresholds[i][1][1] - # This is a node clause # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes if clause_type == 'node': clause_var_1 = clause_variables[0] - subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors) + subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes) subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd) - if atom_trace: - qualified_nodes.append(numba.typed.List(subsets[clause_var_1])) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) - - # Add annotations if necessary - if ann_fn != '': - a = 
numba.typed.List.empty_list(interval.interval_type) - for qn in subsets[clause_var_1]: - a.append(interpretations_node[qn].world[clause_label]) - annotations.append(a) + # Save data for annotations and atom trace + clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1]))) # This is an edge clause elif clause_type == 'edge': @@ -824,16 +1274,9 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n subsets[clause_var_1] = qe[0] subsets[clause_var_2] = qe[1] - if atom_trace: - qualified_nodes.append(numba.typed.List.empty_list(node_type)) - qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2]))) - - # Add annotations if necessary - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])): - a.append(interpretations_edge[qe].world[clause_label]) - annotations.append(a) + # Save data for annotations and atom trace + clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2]))) + else: # This is a comparison clause # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1] @@ -848,8 +1291,8 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n # It's a node comparison if len(clause_variables) == 2: clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors) - subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, neighbors) + subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes) + subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, nodes) # 1, 2 qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, 
clause_label, clause_bnd) @@ -877,19 +1320,10 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n # Update subsets with final qualified nodes subsets[clause_var_1] = qualified_nodes_1 subsets[clause_var_2] = qualified_nodes_2 - qualified_comparison_nodes = numba.typed.List(qualified_nodes_1) - qualified_comparison_nodes.extend(qualified_nodes_2) - if atom_trace: - qualified_nodes.append(qualified_comparison_nodes) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # Save data for annotations and atom trace + clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2]))) - # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qn in qualified_comparison_nodes: - a.append(interval.closed(1, 1)) - annotations.append(a) # Edge comparison. Compare stage else: satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator, @@ -903,39 +1337,19 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n subsets[clause_var_2_source] = qualified_nodes_2_source subsets[clause_var_2_target] = qualified_nodes_2_target - qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target)) - qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target)) - qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1) - qualified_comparison_nodes.extend(qualified_comparison_nodes_2) - - if atom_trace: - qualified_nodes.append(numba.typed.List.empty_list(node_type)) - qualified_edges.append(qualified_comparison_nodes) - - # Add annotations for comparison clause. 
For now, we don't distinguish between LHS and RHS annotations - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qe in qualified_comparison_nodes: - a.append(interval.closed(1, 1)) - annotations.append(a) + # Save data for annotations and atom trace + clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target]))) # Non comparison clause else: - if threshold_quantifier_type == 'total': - if clause_type == 'node': - neigh_len = len(subset) - else: - neigh_len = sum([len(l) for l in subset_target]) - - # Available is all neighbors that have a particular label with bound inside [0,1] - elif threshold_quantifier_type == 'available': - if clause_type == 'node': - neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0,1))) - else: - neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0,1), reverse_graph)[0]) + if clause_type == 'node': + satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction + else: + satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction - qualified_neigh_len = len(subsets[clause_var_1]) - satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction + # Refine subsets based on any updates + if satisfaction: + satisfaction = refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction # Exit loop if even one clause is not satisfied if not satisfaction: @@ -962,8 +1376,87 @@ def _ground_node_rule(rule, interpretations_node, 
interpretations_edge, nodes, n else: edges_to_be_added[1].append(target) + # Loop through the clause data and setup final annotations and trace variables + # 1. Add qualified nodes/edges to trace + # 2. Add annotations to annotation function variable + for i, clause in enumerate(clause_type_and_variables): + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + + if clause_type == 'node': + clause_var_1 = clause_variables[0] + # 1. + if atom_trace: + qualified_nodes.append(numba.typed.List(subsets[clause_var_1])) + qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # 2. + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + for qn in subsets[clause_var_1]: + a.append(interpretations_node[qn].world[clause_label]) + annotations.append(a) + + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables + # 1. + if atom_trace: + qualified_nodes.append(numba.typed.List.empty_list(node_type)) + qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2]))) + # 2. + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])): + a.append(interpretations_edge[qe].world[clause_label]) + annotations.append(a) + + elif clause_type == 'node-comparison': + clause_var_1, clause_var_2 = clause_variables + qualified_nodes_1 = subsets[clause_var_1] + qualified_nodes_2 = subsets[clause_var_2] + qualified_comparison_nodes = numba.typed.List(qualified_nodes_1) + qualified_comparison_nodes.extend(qualified_nodes_2) + # 1. + if atom_trace: + qualified_nodes.append(qualified_comparison_nodes) + qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # 2. + # Add annotations for comparison clause. 
For now, we don't distinguish between LHS and RHS annotations + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + for qn in qualified_comparison_nodes: + a.append(interval.closed(1, 1)) + annotations.append(a) + + elif clause_type == 'edge-comparison': + clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables + qualified_nodes_1_source = subsets[clause_var_1_source] + qualified_nodes_1_target = subsets[clause_var_1_target] + qualified_nodes_2_source = subsets[clause_var_2_source] + qualified_nodes_2_target = subsets[clause_var_2_target] + qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target)) + qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target)) + qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1) + qualified_comparison_nodes.extend(qualified_comparison_nodes_2) + # 1. + if atom_trace: + qualified_nodes.append(numba.typed.List.empty_list(node_type)) + qualified_edges.append(qualified_comparison_nodes) + # 2. + # Add annotations for comparison clause. 
For now, we don't distinguish between LHS and RHS annotations + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + for qe in qualified_comparison_nodes: + a.append(interval.closed(1, 1)) + annotations.append(a) + # node/edge, annotations, qualified nodes, qualified edges, edges to be added - applicable_rules.append((target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) + applicable_rules_threadsafe[piter] = numba.typed.List([(target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added)]) + + # Merge all threadsafe rules into one single array + for applicable_rule in applicable_rules_threadsafe: + if len(applicable_rule) > 0: + applicable_rules.append(applicable_rule[0]) return applicable_rules @@ -979,6 +1472,10 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e # We return a list of tuples which specify the target nodes/edges that have made the rule body true applicable_rules = numba.typed.List.empty_list(edge_applicable_rule_type) + + # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe + # One array for each node, then condense into a single list later + applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(edge_applicable_rule_type) for _ in edges]) # Return empty list if rule is not node rule if rule_type != 'edge': @@ -1003,6 +1500,7 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + clause_type_and_variables = numba.typed.List.empty_list(clause_data) satisfaction = True for i, clause in enumerate(clauses): @@ -1013,27 +1511,16 @@ def 
_ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e clause_bnd = clause[3] clause_operator = clause[4] - # Unpack thresholds - # This value is total/available - threshold_quantifier_type = thresholds[i][1][1] - # This is a node clause # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes if clause_type == 'node': clause_var_1 = clause_variables[0] - subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors) - + subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes) + subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd) - if atom_trace: - qualified_nodes.append(numba.typed.List(subsets[clause_var_1])) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) - # Add annotations if necessary - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qn in subsets[clause_var_1]: - a.append(interpretations_node[qn].world[clause_label]) - annotations.append(a) + # Save data for annotations and atom trace + clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1]))) # This is an edge clause elif clause_type == 'edge': @@ -1045,17 +1532,9 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e subsets[clause_var_1] = qe[0] subsets[clause_var_2] = qe[1] - if atom_trace: - qualified_nodes.append(numba.typed.List.empty_list(node_type)) - qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2]))) - - # Add annotations if necessary - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])): - a.append(interpretations_edge[qe].world[clause_label]) - annotations.append(a) - + # Save data for annotations and atom trace + 
clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2]))) + else: # This is a comparison clause # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1] @@ -1066,17 +1545,17 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e # 2. get qualified nodes/edges as well as number associated for second predicate # 3. if there's no number in steps 1 or 2 return false clause # 4. do comparison with each qualified component from step 1 with each qualified component in step 2 - + # It's a node comparison if len(clause_variables) == 2: clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors) - subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, neighbors) - + subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes) + subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, nodes) + # 1, 2 qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd) qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd) - + # It's an edge comparison elif len(clause_variables) == 4: clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3] @@ -1099,19 +1578,10 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e # Update subsets with final qualified nodes subsets[clause_var_1] = qualified_nodes_1 subsets[clause_var_2] = qualified_nodes_2 - qualified_comparison_nodes = numba.typed.List(qualified_nodes_1) - qualified_comparison_nodes.extend(qualified_nodes_2) - if 
atom_trace: - qualified_nodes.append(qualified_comparison_nodes) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # Save data for annotations and atom trace + clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2]))) - # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - for qn in qualified_comparison_nodes: - a.append(interval.closed(1, 1)) - annotations.append(a) # Edge comparison. Compare stage else: satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator, @@ -1125,89 +1595,464 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e subsets[clause_var_2_source] = qualified_nodes_2_source subsets[clause_var_2_target] = qualified_nodes_2_target + # Save data for annotations and atom trace + clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target]))) + + # Non comparison clause + else: + if clause_type == 'node': + satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction + else: + satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction + + # Refine subsets based on any updates + if satisfaction: + satisfaction = refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction + + # Exit loop if even one clause is not satisfied + if not satisfaction: + break + + # Here we are done going 
through each clause of the rule
+            # If all clauses were satisfied, proceed to collect annotations and prepare edges to be added
+            if satisfaction:
+                # Loop through the clause data and set up final annotations and trace variables
+                # 1. Add qualified nodes/edges to trace
+                # 2. Add annotations to annotation function variable
+                for i, clause in enumerate(clause_type_and_variables):
+                    clause_type = clause[0]
+                    clause_label = clause[1]
+                    clause_variables = clause[2]
+
+                    if clause_type == 'node':
+                        clause_var_1 = clause_variables[0]
+                        # 1.
+                        if atom_trace:
+                            qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
+                            qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                        # 2.
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            for qn in subsets[clause_var_1]:
+                                a.append(interpretations_node[qn].world[clause_label])
+                            annotations.append(a)
+
+                    elif clause_type == 'edge':
+                        clause_var_1, clause_var_2 = clause_variables
+                        # 1.
+                        if atom_trace:
+                            qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                            qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
+                        # 2.
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
+                                a.append(interpretations_edge[qe].world[clause_label])
+                            annotations.append(a)
+
+                    elif clause_type == 'node-comparison':
+                        clause_var_1, clause_var_2 = clause_variables
+                        qualified_nodes_1 = subsets[clause_var_1]
+                        qualified_nodes_2 = subsets[clause_var_2]
+                        qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
+                        qualified_comparison_nodes.extend(qualified_nodes_2)
+                        # 1.
+                        if atom_trace:
+                            qualified_nodes.append(qualified_comparison_nodes)
+                            qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                        # 2.
+                        # Add annotations for comparison clause. 
For now, we don't distinguish between LHS and RHS annotations + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + for qn in qualified_comparison_nodes: + a.append(interval.closed(1, 1)) + annotations.append(a) + + elif clause_type == 'edge-comparison': + clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables + qualified_nodes_1_source = subsets[clause_var_1_source] + qualified_nodes_1_target = subsets[clause_var_1_target] + qualified_nodes_2_source = subsets[clause_var_2_source] + qualified_nodes_2_target = subsets[clause_var_2_target] qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target)) qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target)) qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1) qualified_comparison_nodes.extend(qualified_comparison_nodes_2) - + # 1. if atom_trace: qualified_nodes.append(numba.typed.List.empty_list(node_type)) qualified_edges.append(qualified_comparison_nodes) - + # 2. # Add annotations for comparison clause. 
For now, we don't distinguish between LHS and RHS annotations if ann_fn != '': a = numba.typed.List.empty_list(interval.interval_type) for qe in qualified_comparison_nodes: a.append(interval.closed(1, 1)) annotations.append(a) - - # Non comparison clause - else: - if threshold_quantifier_type == 'total': - if clause_type == 'node': - neigh_len = len(subset) - else: - neigh_len = sum([len(l) for l in subset_target]) - - # Available is all neighbors that have a particular label with bound inside [0,1] - elif threshold_quantifier_type == 'available': - if clause_type == 'node': - neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1))) - else: - neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0]) - - qualified_neigh_len = len(subsets[clause_var_1]) - satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction - - # Exit loop if even one clause is not satisfied - if not satisfaction: - break + # node/edge, annotations, qualified nodes, qualified edges, edges to be added + applicable_rules_threadsafe[piter] = numba.typed.List([(target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)]) - # Here we are done going through each clause of the rule - # If all clauses we're satisfied, proceed to collect annotations and prepare edges to be added - if satisfaction: - # Collect edges to be added - source, target, _ = rule_edges + # Merge all threadsafe rules into one single array + for applicable_rule in applicable_rules_threadsafe: + if len(applicable_rule) > 0: + applicable_rules.append(applicable_rule[0]) - # Edges to be added - if source != '' and target != '': - # Check if edge nodes are source/target - if source == '__source': - edges_to_be_added[0].append(target_edge[0]) - elif source == '__target': - 
edges_to_be_added[0].append(target_edge[1]) - elif source in subsets: - edges_to_be_added[0].extend(subsets[source]) - else: - edges_to_be_added[0].append(source) + return applicable_rules - if target == '__source': - edges_to_be_added[1].append(target_edge[0]) - elif target == '__target': - edges_to_be_added[1].append(target_edge[1]) - elif target in subsets: - edges_to_be_added[1].extend(subsets[target]) - else: - edges_to_be_added[1].append(target) - # node/edge, annotations, qualified nodes, qualified edges, edges to be added - applicable_rules.append((target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) +@numba.njit(cache=True) +def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors): + # Loop through the dependency graph and refine the groundings that have connections + all_variables_refined = numba.typed.List(clause_variables) + variables_just_refined = numba.typed.List(clause_variables) + new_variables_refined = numba.typed.List.empty_list(numba.types.string) + while len(variables_just_refined) > 0: + for refined_variable in variables_just_refined: + # Refine all the neighbors of the refined variable + if refined_variable in dependency_graph_neighbors: + for neighbor in dependency_graph_neighbors[refined_variable]: + old_edge_groundings = groundings_edges[(refined_variable, neighbor)] + new_node_groundings = groundings[refined_variable] + + # Delete old groundings for the variable being refined + del groundings[neighbor] + groundings[neighbor] = numba.typed.List.empty_list(node_type) + + # Update the edge groundings and node groundings + qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings]) + groundings_neighbor_set = set(groundings[neighbor]) + for e in qualified_groundings: + if e[1] not in groundings_neighbor_set: + groundings[neighbor].append(e[1]) + groundings_neighbor_set.add(e[1]) + 
groundings_edges[(refined_variable, neighbor)] = qualified_groundings + + # Add the neighbor to the list of refined variables so that we can refine for all its neighbors + if neighbor not in all_variables_refined: + new_variables_refined.append(neighbor) + + if refined_variable in dependency_graph_reverse_neighbors: + for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]: + old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)] + new_node_groundings = groundings[refined_variable] + + # Delete old groundings for the variable being refined + del groundings[reverse_neighbor] + groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type) + + # Update the edge groundings and node groundings + qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings]) + groundings_reverse_neighbor_set = set(groundings[reverse_neighbor]) + for e in qualified_groundings: + if e[0] not in groundings_reverse_neighbor_set: + groundings[reverse_neighbor].append(e[0]) + groundings_reverse_neighbor_set.add(e[0]) + groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings + + # Add the neighbor to the list of refined variables so that we can refine for all its neighbors + if reverse_neighbor not in all_variables_refined: + new_variables_refined.append(reverse_neighbor) + + variables_just_refined = numba.typed.List(new_variables_refined) + all_variables_refined.extend(new_variables_refined) + new_variables_refined.clear() + + +@numba.njit(cache=True) +def refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph): + """NOTE: DEPRECATED""" + # Loop through all clauses till clause i-1 and update subsets recursively + # Then check if the clause still satisfies the thresholds + clause = clauses[i] + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + 
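(Reviewer note, not part of the patch.) `refine_groundings` above implements a worklist-style constraint propagation: when one variable's groundings shrink, every variable connected to it in the dependency graph is re-pruned through the shared edge groundings, and newly changed variables are queued in turn. A plain-Python stand-in for the forward-neighbor half of that idea (hypothetical names; the real code also walks reverse neighbors and uses numba typed lists):

```python
def propagate(groundings, edge_groundings, var_neighbors, seed_vars):
    """After the groundings of seed_vars shrink, prune each dependent
    variable's node and edge groundings until no more changes occur.

    groundings:      dict var -> list of candidate nodes
    edge_groundings: dict (var, var) -> list of (source, target) node pairs
    var_neighbors:   dict var -> list of variables it constrains (forward edges)
    """
    worklist = list(seed_vars)
    while worklist:
        v = worklist.pop()
        for nbr in var_neighbors.get(v, ()):
            allowed = set(groundings[v])
            # Keep only edges whose source is still a valid grounding of v
            kept_edges = [e for e in edge_groundings[(v, nbr)] if e[0] in allowed]
            # Neighbor groundings become the distinct targets of kept edges
            new_nodes, seen = [], set()
            for _, tgt in kept_edges:
                if tgt not in seen:
                    seen.add(tgt)
                    new_nodes.append(tgt)
            edge_groundings[(v, nbr)] = kept_edges
            # Re-queue the neighbor only if its groundings actually changed
            if set(new_nodes) != set(groundings[nbr]):
                groundings[nbr] = new_nodes
                worklist.append(nbr)
    return groundings
```

The change check plays the role that `all_variables_refined` / `variables_just_refined` play in the patch: it stops the propagation once a fixpoint is reached instead of revisiting every variable.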
clause_bnd = clause[3] + clause_operator = clause[4] + + # Keep track of the variables that were refined (start with clause_variables) and variables that need refining + satisfaction = True + all_variables_refined = numba.typed.List(clause_variables) + variables_just_refined = numba.typed.List(clause_variables) + new_variables_refined = numba.typed.List.empty_list(numba.types.string) + while len(variables_just_refined) > 0: + for j in range(i): + c = clauses[j] + c_type = c[0] + c_label = c[1] + c_variables = c[2] + c_bnd = c[3] + c_operator = c[4] + + # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables + # If yes, then update the variable that is with it in the clause + if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2): + for v in variables_just_refined: + for k, cv in enumerate(c_variables): + if cv == v: + # Find which variable needs to be refined, 1st or 2nd. + # 2nd variable needs refining + if k == 0: + refine_idx = 1 + refine_v = c_variables[1] + # 1st variable needs refining + elif k == 1: + refine_idx = 0 + refine_v = c_variables[0] + # 2nd variable needs refining + elif k == 2: + refine_idx = 1 + refine_v = c_variables[3] + # 1st variable needs refining + else: + refine_idx = 0 + refine_v = c_variables[2] + + # Refine the variable + if refine_v not in all_variables_refined: + new_variables_refined.append(refine_v) + + if c_type == 'edge': + clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v) + del subsets[refine_v] + subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes) + + # Get qualified edges + qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph) + subsets[clause_var_1] = qe[0] + subsets[clause_var_2] = qe[1] + + # Check if we still satisfy the clause + satisfaction = 
check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction + else: + # We do not support refinement for comparison clauses + pass + + if not satisfaction: + return satisfaction + + variables_just_refined = numba.typed.List(new_variables_refined) + all_variables_refined.extend(new_variables_refined) + new_variables_refined.clear() + + return satisfaction + + +@numba.njit(cache=True) +def refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph): + """NOTE: DEPRECATED""" + # Loop through all clauses till clause i-1 and update subsets recursively + # Then check if the clause still satisfies the thresholds + clause = clauses[i] + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + clause_bnd = clause[3] + clause_operator = clause[4] + + # Keep track of the variables that were refined (start with clause_variables) and variables that need refining + satisfaction = True + all_variables_refined = numba.typed.List(clause_variables) + variables_just_refined = numba.typed.List(clause_variables) + new_variables_refined = numba.typed.List.empty_list(numba.types.string) + while len(variables_just_refined) > 0: + for j in range(i): + c = clauses[j] + c_type = c[0] + c_label = c[1] + c_variables = c[2] + c_bnd = c[3] + c_operator = c[4] + + # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables + # If yes, then update the variable that is with it in the clause + if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2): + for v in variables_just_refined: + for k, cv in enumerate(c_variables): + if cv == v: + # Find which variable needs to be refined, 1st or 2nd. 
+ # 2nd variable needs refining + if k == 0: + refine_idx = 1 + refine_v = c_variables[1] + # 1st variable needs refining + elif k == 1: + refine_idx = 0 + refine_v = c_variables[0] + # 2nd variable needs refining + elif k == 2: + refine_idx = 1 + refine_v = c_variables[3] + # 1st variable needs refining + else: + refine_idx = 0 + refine_v = c_variables[2] + + # Refine the variable + if refine_v not in all_variables_refined: + new_variables_refined.append(refine_v) + + if c_type == 'edge': + clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v) + del subsets[refine_v] + subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes) + + # Get qualified edges + qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph) + subsets[clause_var_1] = qe[0] + subsets[clause_var_2] = qe[1] + + # Check if we still satisfy the clause + satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction + else: + # We do not support refinement for comparison clauses + pass + + if not satisfaction: + return satisfaction + + variables_just_refined = numba.typed.List(new_variables_refined) + all_variables_refined.extend(new_variables_refined) + new_variables_refined.clear() + + return satisfaction + + +@numba.njit(cache=True) +def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold): + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = len(grounding) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, 
interval.closed(0, 1))) + + qualified_neigh_len = len(qualified_grounding) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold): + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = len(grounding) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1))) + + qualified_neigh_len = len(qualified_grounding) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, threshold): + """NOTE: DEPRECATED""" + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = len(subset) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1))) + + # Only take length of clause_var_1 because length of subsets of var_1 and var_2 are supposed to be equal + qualified_neigh_len = len(subsets[clause_var_1]) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, threshold, reverse_graph): + """NOTE: DEPRECATED""" + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = sum([len(l) for l in 
subset_target]) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0]) + + qualified_neigh_len = len(subsets[clause_var_1]) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction - return applicable_rules + +@numba.njit(cache=True) +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): + # The groundings for a node clause can be either a previous grounding or all possible nodes + if l in predicate_map: + grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] + else: + grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] + return grounding + + +@numba.njit(cache=True) +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): + # There are 4 cases for predicate(Y,Z): + # 1. Both predicate variables Y and Z have not been encountered before + # 2. The source variable Y has not been encountered before but the target variable Z has + # 3. The target variable Z has not been encountered before but the source variable Y has + # 4. 
Both predicate variables Y and Z have been encountered before + edge_groundings = numba.typed.List.empty_list(edge_type) + + # Case 1: + # We take all edges carrying the predicate from the predicate map if available, else all edges + if clause_var_1 not in groundings and clause_var_2 not in groundings: + if l in predicate_map: + edge_groundings = predicate_map[l] + else: + edge_groundings = edges + + # Case 2: + # We replace Y by the sources of Z + elif clause_var_1 not in groundings and clause_var_2 in groundings: + for n in groundings[clause_var_2]: + es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]]) + edge_groundings.extend(es) + + # Case 3: + # We replace Z by the neighbors of Y + elif clause_var_1 in groundings and clause_var_2 not in groundings: + for n in groundings[clause_var_1]: + es = numba.typed.List([(n, nn) for nn in neighbors[n]]) + edge_groundings.extend(es) + + # Case 4: + # We have seen both variables before + else: + # We have already seen these two variables in an edge clause + if (clause_var_1, clause_var_2) in groundings_edges: + edge_groundings = groundings_edges[(clause_var_1, clause_var_2)] + # We have seen both these variables but not in an edge clause together + else: + groundings_clause_var_2_set = set(groundings[clause_var_2]) + for n in groundings[clause_var_1]: + es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set]) + edge_groundings.extend(es) + + return edge_groundings @numba.njit(cache=True) -def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors): +def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes): + """NOTE: DEPRECATED""" # The groundings for node clauses are either the target node, all other nodes, or an existing subset of nodes if clause_var_1 == '__target': subset = numba.typed.List([target_node]) else: - subset = neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1] + nodes_without_target = 
numba.typed.List([n for n in nodes if n != target_node]) + subset = nodes_without_target if clause_var_1 not in subsets else subsets[clause_var_1] return subset @numba.njit(cache=True) def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes): + """NOTE: DEPRECATED""" # There are 5 cases for predicate(Y,Z): # 1. Either one or both of Y, Z are the target node # 2. Both predicate variables Y and Z have not been encountered before @@ -1234,10 +2079,10 @@ def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, su subset_target = numba.typed.List([numba.typed.List([target_node]) for _ in subset_source]) # Case 2: - # We replace Y by all nodes and Z by the neighbors of each of these nodes + # We replace Y by all nodes (except target_node) and Z by the neighbors of each of these nodes elif clause_var_1 not in subsets and clause_var_2 not in subsets: - subset_source = numba.typed.List(nodes) - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_source = numba.typed.List([n for n in nodes if n != target_node]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_node]) for n in subset_source]) # Case 3: # We replace Y by the sources of Z @@ -1248,37 +2093,50 @@ def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, su for n in subsets[clause_var_2]: sources = reverse_neighbors[n] for source in sources: - subset_source.append(source) - subset_target.append(numba.typed.List([n])) + if source != target_node: + subset_source.append(source) + subset_target.append(numba.typed.List([n])) # Case 4: # We replace Z by the neighbors of Y elif clause_var_1 in subsets and clause_var_2 not in subsets: subset_source = subsets[clause_var_1] - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != 
target_node]) for n in subset_source]) # Case 5: else: subset_source = subsets[clause_var_1] subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source]) + # If any of the subsets are empty return them in the correct type + if len(subset_source) == 0: + subset_source = numba.typed.List.empty_list(node_type) + subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + # If any sub lists in subset target are empty, add correct type for empty list + for i, t in enumerate(subset_target): + if len(t) == 0: + subset_target[i] = numba.typed.List.empty_list(node_type) + return subset_source, subset_target @numba.njit(cache=True) -def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors): +def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes): + """NOTE: DEPRECATED""" # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes if clause_var_1 == '__source': subset = numba.typed.List([target_edge[0]]) elif clause_var_1 == '__target': subset = numba.typed.List([target_edge[1]]) else: - subset = neighbors[target_edge[0]] if clause_var_1 not in subsets else subsets[clause_var_1] + nodes_without_target_or_source = numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]]) + subset = nodes_without_target_or_source if clause_var_1 not in subsets else subsets[clause_var_1] return subset @numba.njit(cache=True) def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes): + """NOTE: DEPRECATED""" # There are 5 cases for predicate(Y,Z): # 1. Either one or both of Y, Z are the source or target node # 2. 
Both predicate variables Y and Z have not been encountered before @@ -1324,10 +2182,10 @@ def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, su subset_target = numba.typed.List([numba.typed.List([target_edge[1]]) for _ in subset_source]) # Case 2: - # We replace Y by all nodes and Z by the neighbors of each of these nodes + # We replace Y by all nodes (except source/target) and Z by the neighbors of each of these nodes elif clause_var_1 not in subsets and clause_var_2 not in subsets: - subset_source = numba.typed.List(nodes) - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_source = numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source]) # Case 3: # We replace Y by the sources of Z @@ -1338,29 +2196,62 @@ def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, su for n in subsets[clause_var_2]: sources = reverse_neighbors[n] for source in sources: - subset_source.append(source) - subset_target.append(numba.typed.List([n])) + if source != target_edge[0] and source != target_edge[1]: + subset_source.append(source) + subset_target.append(numba.typed.List([n])) # Case 4: # We replace Z by the neighbors of Y elif clause_var_1 in subsets and clause_var_2 not in subsets: subset_source = subsets[clause_var_1] - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source]) # Case 5: else: subset_source = subsets[clause_var_1] subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source]) + # If any of the subsets are empty return them in the correct type + if len(subset_source) == 0: + subset_source = 
numba.typed.List.empty_list(node_type) + subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + # If any sub lists in subset target are empty, add correct type for empty list + for i, t in enumerate(subset_target): + if len(t) == 0: + subset_target[i] = numba.typed.List.empty_list(node_type) + return subset_source, subset_target +@numba.njit(cache=True) +def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd): + # Filter the grounding by the predicate and bound of the clause + qualified_groundings = numba.typed.List.empty_list(node_type) + for n in grounding: + if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)): + qualified_groundings.append(n) + + return qualified_groundings + + +@numba.njit(cache=True) +def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd): + # Filter the grounding by the predicate and bound of the clause + qualified_groundings = numba.typed.List.empty_list(edge_type) + for e in grounding: + if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)): + qualified_groundings.append(e) + + return qualified_groundings + + @numba.njit(cache=True) def get_qualified_components_node_clause(interpretations_node, candidates, l, bnd): + """NOTE: DEPRECATED""" # Get all the qualified neighbors for a particular clause qualified_nodes = numba.typed.List.empty_list(node_type) for n in candidates: - if is_satisfied_node(interpretations_node, n, (l, bnd)): + if is_satisfied_node(interpretations_node, n, (l, bnd)) and n not in qualified_nodes: qualified_nodes.append(n) return qualified_nodes @@ -1368,6 +2259,7 @@ def get_qualified_components_node_clause(interpretations_node, candidates, l, bn @numba.njit(cache=True) def get_qualified_components_node_comparison_clause(interpretations_node, candidates, l, bnd): + """NOTE: DEPRECATED""" # Get all the qualified neighbors for a particular comparison clause and return them along with the 
number associated qualified_nodes = numba.typed.List.empty_list(node_type) qualified_nodes_numbers = numba.typed.List.empty_list(numba.types.float64) @@ -1382,6 +2274,7 @@ def get_qualified_components_node_comparison_clause(interpretations_node, candid @numba.njit(cache=True) def get_qualified_components_edge_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph): + """NOTE: DEPRECATED""" # Get all the qualified sources and targets for a particular clause qualified_nodes_source = numba.typed.List.empty_list(node_type) qualified_nodes_target = numba.typed.List.empty_list(node_type) @@ -1397,6 +2290,7 @@ def get_qualified_components_edge_clause(interpretations_edge, candidates_source @numba.njit(cache=True) def get_qualified_components_edge_comparison_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph): + """NOTE: DEPRECATED""" # Get all the qualified sources and targets for a particular clause qualified_nodes_source = numba.typed.List.empty_list(node_type) qualified_nodes_target = numba.typed.List.empty_list(node_type) @@ -1415,6 +2309,7 @@ def get_qualified_components_edge_comparison_clause(interpretations_edge, candid @numba.njit(cache=True) def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, qualified_nodes_2): + """NOTE: DEPRECATED""" result = False final_qualified_nodes_1 = numba.typed.List.empty_list(node_type) final_qualified_nodes_2 = numba.typed.List.empty_list(node_type) @@ -1447,6 +2342,7 @@ def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, @numba.njit(cache=True) def compare_numbers_edge_predicate(numbers_1, numbers_2, op, qualified_nodes_1a, qualified_nodes_1b, qualified_nodes_2a, qualified_nodes_2b): + """NOTE: DEPRECATED""" result = False final_qualified_nodes_1a = numba.typed.List.empty_list(node_type) final_qualified_nodes_1b = numba.typed.List.empty_list(node_type) @@ -1514,7 +2410,7 @@ def 
_satisfies_threshold(num_neigh, num_qualified_component, threshold): @numba.njit(cache=True) -def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): +def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False # This is to prevent a key error in case the label is a specific label try: @@ -1525,6 +2421,10 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat # Add label to world if it is not there if l not in world.world: world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd prev_bnd = world.world[l].copy() @@ -1557,7 +2457,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1==l: + if p1 == l: + if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) @@ -1568,7 +2474,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: 
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2==l: + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) @@ -1603,7 +2515,7 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat @numba.njit(cache=True) -def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): +def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False # This is to prevent a key error in case the label is a specific label try: @@ -1614,6 +2526,10 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat # Add label to world if it is not there if l not in world.world: world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd prev_bnd = world.world[l].copy() @@ -1646,7 +2562,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1==l: + if p1 == l: + 
if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) @@ -1657,7 +2579,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2==l: + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) @@ -1668,7 +2596,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper))) - + # Gather convergence data change = 0 if updated: @@ -1684,7 +2612,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat change = max(change, max_delta) else: change = 1 + ip_update_cnt - + return (updated, change) except: return (False, 0) @@ -1693,20 +2621,20 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat @numba.njit(cache=True) def _update_rule_trace(rule_trace, 
qn, qe, prev_bnd, name): rule_trace.append((qn, qe, prev_bnd.copy(), name)) - + @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (label, interval) in nas: - result = result and is_satisfied_node(interpretations, comp, (label, interval)) + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) return result @numba.njit(cache=True) def is_satisfied_node(interpretations, comp, na): result = False - if (not (na[0] is None or na[1] is None)): + if not (na[0] is None or na[1] is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] @@ -1748,15 +2676,15 @@ def is_satisfied_node_comparison(interpretations, comp, na): @numba.njit(cache=True) def are_satisfied_edge(interpretations, comp, nas): result = True - for (label, interval) in nas: - result = result and is_satisfied_edge(interpretations, comp, (label, interval)) + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) return result @numba.njit(cache=True) def is_satisfied_edge(interpretations, comp, na): result = False - if (not (na[0] is None or na[1] is None)): + if not (na[0] is None or na[1] is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] @@ -1835,19 +2763,25 @@ def check_consistent_edge(interpretations, comp, na): @numba.njit(cache=True) -def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes): +def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): world = interpretations[comp] if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], 
interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}') # Resolve inconsistency and set static world.world[na[0]].set_lower_upper(0, 1) world.world[na[0]].set_static(True) for p1, p2 in ipl: if p1==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}') world.world[p2].set_lower_upper(0, 1) world.world[p2].set_static(True) if store_interpretation_changes: @@ -1855,28 +2789,34 @@ def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, at if p2==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}') 
world.world[p1].set_lower_upper(0, 1) world.world[p1].set_static(True) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) - # Add inconsistent predicates to a list + # Add inconsistent predicates to a list @numba.njit(cache=True) -def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes): +def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): w = interpretations[comp] if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}') # Resolve inconsistency and set static w.world[na[0]].set_lower_upper(0, 1) w.world[na[0]].set_static(True) for p1, p2 in ipl: if p1==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), 
numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}') w.world[p2].set_lower_upper(0, 1) w.world[p2].set_static(True) if store_interpretation_changes: @@ -1884,7 +2824,7 @@ def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, at if p2==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}') w.world[p1].set_lower_upper(0, 1) w.world[p1].set_static(True) if store_interpretation_changes: @@ -1900,7 +2840,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1920,6 +2860,10 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int reverse_neighbors[target].append(source) if l.value!='': interpretations_edge[edge] = world.World(numba.typed.List([l])) + if l in predicate_map: + predicate_map[l].append(edge) + else: + predicate_map[l] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: @@ -1931,32 +2875,38 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int @numba.njit(cache=True) -def 
_add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @numba.njit(cache=True) -def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge): +def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map): source, target = edge edges.remove(edge) del interpretations_edge[edge] + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @numba.njit(cache=True) -def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): +def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map): nodes.remove(node) del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py index 230b641..fd0f582 100644 --- a/pyreason/scripts/interpretation/interpretation_parallel.py +++ b/pyreason/scripts/interpretation/interpretation_parallel.py @@ -1,3 +1,5 @@ +from networkx.classes 
import edges + import pyreason.scripts.numba_wrapper.numba_types.world_type as world import pyreason.scripts.numba_wrapper.numba_types.label_type as label import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval @@ -15,6 +17,12 @@ list_of_nodes = numba.types.ListType(node_type) list_of_edges = numba.types.ListType(edge_type) +# Type for storing clause data +clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string))) + +# Type for storing refine clause data +refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8)) + # Type for facts to be applied facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) @@ -37,6 +45,11 @@ numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) )) +rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) +rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) +rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)) +edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) + class Interpretation: available_labels_node = [] @@ -44,7 +57,7 @@ class Interpretation: specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type)) specific_edge_labels 
= numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type)) - def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode): + def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules): self.graph = graph self.ipl = ipl self.annotation_functions = annotation_functions @@ -55,18 +68,19 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, self.inconsistency_check = inconsistency_check self.store_interpretation_changes = store_interpretation_changes self.update_mode = update_mode + self.allow_ground_rules = allow_ground_rules # For reasoning and reasoning again (contains previous time and previous fp operation cnt) self.time = 0 self.prev_reasoning_data = numba.typed.List([0, 0]) # Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. 
One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true - self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))) - self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))) + self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) + self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string) self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string) - self.rules_to_be_applied_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))) - self.rules_to_be_applied_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))) + self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type) + self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type) self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type) self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type) self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) @@ -94,8 +108,8 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, else: self.available_labels_edge = 
numba.typed.List(self.available_labels_edge) - self.interpretations_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels) - self.interpretations_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels) + self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels) + self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels) # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) @@ -107,7 +121,7 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @staticmethod - @numba.njit(cache=False) + @numba.njit(cache=True) def _init_reverse_neighbors(neighbors): reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) for n, neighbor_nodes in neighbors.items(): @@ -123,9 +137,10 @@ def _init_reverse_neighbors(neighbors): return reverse_neighbors @staticmethod - @numba.njit(cache=False) + @numba.njit(cache=True) def _init_interpretations_node(nodes, available_labels, specific_labels): interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes) # General labels for n in nodes: interpretations[n] = world.World(available_labels) @@ -134,12 +149,19 @@ def _init_interpretations_node(nodes, available_labels, specific_labels): for n in ns: interpretations[n].world[l] = interval.closed(0.0, 1.0) - return interpretations - + for l in available_labels: + predicate_map[l] = numba.typed.List(nodes) + + for l, ns in specific_labels.items(): + predicate_map[l] = 
numba.typed.List(ns) + + return interpretations, predicate_map + @staticmethod - @numba.njit(cache=False) + @numba.njit(cache=True) def _init_interpretations_edge(edges, available_labels, specific_labels): interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges) # General labels for e in edges: interpretations[e] = world.World(available_labels) @@ -148,10 +170,16 @@ def _init_interpretations_edge(edges, available_labels, specific_labels): for e in es: interpretations[e].world[l] = interval.closed(0.0, 1.0) - return interpretations - + for l in available_labels: + predicate_map[l] = numba.typed.List(edges) + + for l, es in specific_labels.items(): + predicate_map[l] = numba.typed.List(es) + + return interpretations, predicate_map + @staticmethod - @numba.njit(cache=False) + @numba.njit(cache=True) def _init_convergence(convergence_bound_threshold, convergence_threshold): if convergence_bound_threshold==-1 and convergence_threshold==-1: convergence_mode = 'perfect_convergence' @@ -171,7 +199,7 @@ def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_thr self._start_fp(rules, max_facts_time, verbose, again) @staticmethod - @numba.njit(cache=False) + @numba.njit(cache=True) def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace): max_time = 0 for fact in facts_node: @@ -193,7 +221,7 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap return max_time def _start_fp(self, rules, max_facts_time, verbose, again): - fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, 
self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again) + fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again) self.time = t - 1 # If we need to reason again, store the next timestep to start from self.prev_reasoning_data[0] = t @@ -202,15 +230,16 @@ def _start_fp(self, rules, max_facts_time, verbose, again): print('Fixed Point iterations:', fp_cnt) @staticmethod - @numba.njit(cache=False) - def reason(interpretations_node, 
interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again): + @numba.njit(cache=True, parallel=True) + def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again): t = prev_reasoning_data[0] fp_cnt = prev_reasoning_data[1] max_rules_time = 0 timestep_loop = True facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type) facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type) - rules_to_remove_idx = numba.typed.List.empty_list(numba.types.int64) + rules_to_remove_idx = set() + rules_to_remove_idx.add(-1) while timestep_loop: if t==tmax: 
timestep_loop = False @@ -238,24 +267,18 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bound_delta = 0 update = False - # Parameters for immediate rules - immediate_node_rule_fire = False - immediate_edge_rule_fire = False - immediate_rule_applied = False - # When delta_t = 0, we don't want to check the same rule with the same node/edge after coming back to the fp operator - nodes_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_nodes) - edges_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_edges) - # Initialize the above - for i in range(len(rules)): - nodes_to_skip[i] = numba.typed.List.empty_list(node_type) - edges_to_skip[i] = numba.typed.List.empty_list(edge_type) - # Start by applying facts # Nodes facts_to_be_applied_node_new.clear() + nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): - if facts_to_be_applied_node[i][0]==t: + if facts_to_be_applied_node[i][0] == t: comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + # If the component is not in the graph, add it + if comp not in nodes_set: + _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) + nodes_set.add(comp) + # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): # Check if we should even store any of the changes to the rule trace etc. 
@@ -273,13 +296,13 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) - + else: # Check for inconsistencies (multiple facts) if check_consistent_node(interpretations_node, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -289,11 +312,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Resolve inconsistency if necessary otherwise override bounds else: + mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes) + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, 
rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -315,9 +338,15 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Edges facts_to_be_applied_edge_new.clear() + edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + # If the component is not in the graph, add it + if comp not in edges_set: + _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge) + edges_set.add(comp) + # Check if bnd is static. 
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute @@ -339,7 +368,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data if check_consistent_edge(interpretations_edge, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -349,11 +378,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Resolve inconsistency else: + mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) else: - mode = 'graph-attribute-fact' if 
graph_attribute else 'fact' - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -382,50 +411,25 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Nodes rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): - # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied - if immediate_node_rule_fire and rules_to_be_applied_node[-1][4]: - i = rules_to_be_applied_node[-1] - idx = len(rules_to_be_applied_node) - 1 - - if i[0]==t: + if i[0] == t: comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5] - sources, targets, edge_l = edges_to_be_added_node_rule[idx] - edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge) - changes_cnt += changes - - # Update bound for newly added edges. 
Use bnd to update all edges if label is specified, else use bnd to update normally - if edge_l.value!='': - for e in edges_added: - if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) - - update = u or update - - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) - else: - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) - - update = u or update + # Check for inconsistencies + if check_consistent_node(interpretations_node, comp, (l, bnd)): + override = True if update_mode == 'override' else False + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes 
+ update = u or update + # Update convergence params + if convergence_mode=='delta_bound': + bound_delta = max(bound_delta, changes) + else: + changes_cnt += changes + # Resolve inconsistency else: - # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) + if inconsistency_check: + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + else: + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -433,32 +437,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bound_delta = max(bound_delta, changes) else: changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes) - else: - u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, 
facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes # Delete rules that have been applied from list by adding index to list - rules_to_remove_idx.append(idx) - - # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back - if immediate: - # If delta_t=0 we want to apply one rule and go back to the fp operator - # If delta_t>0 we want to come back here and apply the rest of the rules - if immediate_edge_rule_fire: - break - elif not immediate_edge_rule_fire and u: - immediate_rule_applied = True - break + rules_to_remove_idx.add(idx) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx]) @@ -469,26 +450,20 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Edges rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): - # If we broke from above loop to apply more rules, then break from here - if immediate_rule_applied and not immediate_edge_rule_fire: - break - # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. 
Which was just added to the list to_be_applied - if immediate_edge_rule_fire and rules_to_be_applied_edge[-1][4]: - i = rules_to_be_applied_edge[-1] - idx = len(rules_to_be_applied_edge) - 1 - - if i[0]==t: + if i[0] == t: comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] - edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge) + edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge) changes_cnt += changes # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally - if edge_l.value!='': + if edge_l.value != '': for e in edges_added: + if interpretations_edge[e].world[edge_l].is_static(): + continue if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update @@ -500,9 +475,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, e, (edge_l, 
bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update @@ -516,7 +491,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Check for inconsistencies if check_consistent_edge(interpretations_edge, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update 
= u or update # Update convergence params @@ -527,9 +502,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -539,17 +514,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data changes_cnt += changes # Delete rules that have been applied from list by adding the index to list - rules_to_remove_idx.append(idx) - - # Break out of the apply rules loop if a rule is immediate. 
Then we go to the fp operator and check for other applicable rules then come back - if immediate: - # If t=0 we want to apply one rule and go back to the fp operator - # If t>0 we want to come back here and apply the rest of the rules - if immediate_edge_rule_fire: - break - elif not immediate_edge_rule_fire and u: - immediate_rule_applied = True - break + rules_to_remove_idx.add(idx) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx]) @@ -560,59 +525,45 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Fixed point # if update or immediate_node_rule_fire or immediate_edge_rule_fire or immediate_rule_applied: if update: - # Increase fp operator count only if not an immediate rule - if not (immediate_node_rule_fire or immediate_edge_rule_fire): - fp_cnt += 1 + # Increase fp operator count + fp_cnt += 1 - for i in range(len(rules)): + # Lists for threadsafe operations (when parallel is on) + rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))]) + rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))]) + if atom_trace: + rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) + rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) + edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))]) + + for i in prange(len(rules)): rule = rules[i] immediate_rule = rule.is_immediate_rule() -
immediate_node_rule_fire = False - immediate_edge_rule_fire = False # Only go through if the rule can be applied within the given timesteps, or we're running until convergence delta_t = rule.get_delta() if t + delta_t <= tmax or tmax == -1 or again: - applicable_node_rules = _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip[i]) - applicable_edge_rules = _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip[i]) + applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules) # Loop through applicable rules and add them to the rules to be applied for later or next fp operation for applicable_rule in applicable_node_rules: - n, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule + n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule # If there is an edge to add or the predicate doesn't exist or the interpretation is not static - if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static(): + if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static(): bnd = annotate(annotation_functions, rule, annotations, rule.get_weights()) # Bound annotations in between 0 and 1 bnd_l = min(max(bnd[0], 0), 1) bnd_u = min(max(bnd[1], 0), 1) bnd = interval.closed(bnd_l, bnd_u) max_rules_time = max(max_rules_time, t + delta_t) - edges_to_be_added_node_rule.append(edges_to_add) - rules_to_be_applied_node.append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + 
rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) if atom_trace: - rules_to_be_applied_node_trace.append((qualified_nodes, qualified_edges, rule.get_name())) + rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance - # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0 - if delta_t != 0: - nodes_to_skip[i].append(n) - - # Handle loop parameters for the next (maybe) fp operation - # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire - # Next fp operation we will skip this rule on this node because anyway there won't be an update + # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: in_loop = True update = False - if immediate_rule and delta_t == 0: - # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done. 
- in_loop = True - update = True - immediate_node_rule_fire = True - break - - # Break, apply immediate rule then come back to check for more applicable rules - if immediate_node_rule_fire: - break for applicable_rule in applicable_edge_rules: e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule @@ -624,51 +575,44 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data bnd_u = min(max(bnd[1], 0), 1) bnd = interval.closed(bnd_l, bnd_u) max_rules_time = max(max_rules_time, t+delta_t) - edges_to_be_added_edge_rule.append(edges_to_add) - rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + # edges_to_be_added_edge_rule.append(edges_to_add) + edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add) + # rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) + rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule())) if atom_trace: - rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name())) + # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name())) + rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance - # It's possible to have an annotation function that keeps changing the value of a node/edge. 
Do this only for delta_t>0 - if delta_t != 0: - edges_to_skip[i].append(e) - - # Handle loop parameters for the next (maybe) fp operation - # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire - # Next fp operation we will skip this rule on this node because anyway there won't be an update + # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: in_loop = True update = False - if immediate_rule and delta_t == 0: - # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done. - in_loop = True - update = True - immediate_edge_rule_fire = True - break - - # Break, apply immediate rule then come back to check for more applicable rules - if immediate_edge_rule_fire: - break - - # Go through all the rules and go back to applying the rules if we came here because of an immediate rule where delta_t>0 - if immediate_rule_applied and not (immediate_node_rule_fire or immediate_edge_rule_fire): - immediate_rule_applied = False - in_loop = True - update = False - continue - + + # Update lists after parallel run + for i in range(len(rules)): + if len(rules_to_be_applied_node_threadsafe[i]) > 0: + rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) + if len(rules_to_be_applied_edge_threadsafe[i]) > 0: + rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) + if atom_trace: + if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: + rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) + if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: + rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) + if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: + edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) + # Check for convergence after each timestep (perfect convergence or convergence specified by user) # 
Check number of changed interpretations or max bound change # User specified convergence - if convergence_mode=='delta_interpretation': + if convergence_mode == 'delta_interpretation': if changes_cnt <= convergence_delta: if verbose: print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation') # Be consistent with time returned when we don't converge t += 1 break - elif convergence_mode=='delta_bound': + elif convergence_mode == 'delta_bound': if bound_delta <= convergence_delta: if verbose: print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation') @@ -678,8 +622,8 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data # Perfect convergence # Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable # If no more rules/facts to be applied - elif convergence_mode=='perfect_convergence': - if t>=max_facts_time and t>=max_rules_time: + elif convergence_mode == 'perfect_convergence': + if t >= max_facts_time and t >= max_rules_time: if verbose: print(f'\nConverged at time: {t}') # Be consistent with time returned when we don't converge @@ -693,7 +637,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data def add_edge(self, edge, l): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally @@ -704,19 +648,19 @@ def add_node(self, node, labels): def delete_edge(self, edge): # This function
is useful for pyreason gym, called externally - _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge) + _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge) def delete_node(self, node): # This function is useful for pyreason gym, called externally - _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) + _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node) - def get_interpretation_dict(self): + def get_dict(self): # This function can be called externally to retrieve a dict of the interpretation values # Only values in the rule trace will be added # Initialize interpretations for each time and node and edge interpretations = {} - for t in range(self.tmax+1): + for t in range(self.time+1): interpretations[t] = {} for node in self.nodes: interpretations[t][node] = InterpretationDict() @@ -730,7 +674,7 @@ def get_interpretation_dict(self): # If canonical, update all following timesteps as well if self.canonical: - for t in range(time+1, self.tmax+1): + for t in range(time+1, self.time+1): interpretations[t][node][l._value] = (bnd.lower, bnd.upper) # Update interpretation edges @@ -740,13 +684,526 @@ def get_interpretation_dict(self): # If canonical, update all following timesteps as well if self.canonical: - for t in range(time+1, self.tmax+1): + for t in range(time+1, self.time+1): interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) return interpretations + def query(self, query, return_bool=True): + """ + This function is used to query the graph after reasoning + :param query: The query string of the form `pred(node)` or `pred(edge)` or `pred(node) : [l, u]` + :param return_bool: If True, returns boolean of query, else the bounds associated with it + :return: bool, or bounds + """ + # Parse the query + query = query.replace(' ', '') + + if ':' in query: + pred_comp, bounds = query.split(':') + bounds = bounds.replace('[', '').replace(']', '') + l, u = bounds.split(',') + l, u = float(l), float(u) + else: + if query[0] == '~': + pred_comp = query[1:] + l, u = 0, 0 + else: + pred_comp = query + l, u = 1, 1 + + bnd = interval.closed(l, u) + + # Split predicate and component + idx = pred_comp.find('(') + pred = label.Label(pred_comp[:idx]) + component = pred_comp[idx + 1:-1] + + if ',' in component: + component = tuple(component.split(',')) + comp_type = 'edge' + else: + comp_type = 'node' + + # Check if the component exists + if comp_type == 'node': + if component not in self.nodes: + return False if return_bool else (0, 0) + else: + if component not in self.edges: + return False if return_bool else (0, 0) + + # Check if the predicate exists + if comp_type == 'node': + if pred not in self.interpretations_node[component].world: + return False if return_bool else (0, 0) + else: + if pred not in self.interpretations_edge[component].world: + return False if return_bool else (0, 0) + + # Check if the bounds are satisfied + if comp_type == 'node': + if self.interpretations_node[component].world[pred] in bnd: + return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper) + else: + return False if return_bool else (0, 0) + else: + if
self.interpretations_edge[component].world[pred] in bnd: + return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper) + else: + return False if return_bool else (0, 0) + -@numba.njit(cache=False, parallel=True) +@numba.njit(cache=True) +def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules): + # Extract rule params + rule_type = rule.get_type() + head_variables = rule.get_head_variables() + clauses = rule.get_clauses() + thresholds = rule.get_thresholds() + ann_fn = rule.get_annotation_function() + rule_edges = rule.get_edges() + + if rule_type == 'node': + head_var_1 = head_variables[0] + else: + head_var_1, head_var_2 = head_variables[0], head_variables[1] + + # We return a list of tuples which specify the target nodes/edges that have made the rule body true + applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type) + applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type) + + # Grounding procedure + # 1. Go through each clause and check which variables have not been initialized in groundings + # 2. 
Check satisfaction of variables based on the predicate in the clause + + # Grounding variable that maps variables in the body to a list of grounded nodes + # Grounding edges that maps edge variables to a list of edges + groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes) + groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges) + + # Dependency graph that keeps track of the connections between the variables in the body + dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) + dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) + + nodes_set = set(nodes) + edges_set = set(edges) + + satisfaction = True + for i, clause in enumerate(clauses): + # Unpack clause variables + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + clause_bnd = clause[3] + clause_operator = clause[4] + + # This is a node clause + if clause_type == 'node': + clause_var_1 = clause_variables[0] + + # Get subset of nodes that can be used to ground the variable + # If we allow ground atoms, we can use the nodes directly + if allow_ground_rules and clause_var_1 in nodes_set: + grounding = numba.typed.List([clause_var_1]) + else: + grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes) + + # Narrow subset based on predicate + qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd) + groundings[clause_var_1] = qualified_groundings + qualified_groundings_set = set(qualified_groundings) + for c1, c2 in groundings_edges: + if c1 == clause_var_1: + groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set]) + if c2 == clause_var_1: + groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in 
qualified_groundings_set]) + + # Check satisfaction of those nodes wrt the threshold + # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule + # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges + # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: + satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction + + # This is an edge clause + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] + + # Get subset of edges that can be used to ground the variables + # If we allow ground atoms, we can use the nodes directly + if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set: + grounding = numba.typed.List([(clause_var_1, clause_var_2)]) + else: + grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges) + + # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster) + qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd) + + # Check satisfaction of those edges wrt the threshold + # Only check satisfaction if the default threshold is used. 
This saves us from grounding the rest of the rule + # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges + # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: + satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction + + # Update the groundings + groundings[clause_var_1] = numba.typed.List.empty_list(node_type) + groundings[clause_var_2] = numba.typed.List.empty_list(node_type) + groundings_clause_1_set = set(groundings[clause_var_1]) + groundings_clause_2_set = set(groundings[clause_var_2]) + for e in qualified_groundings: + if e[0] not in groundings_clause_1_set: + groundings[clause_var_1].append(e[0]) + groundings_clause_1_set.add(e[0]) + if e[1] not in groundings_clause_2_set: + groundings[clause_var_2].append(e[1]) + groundings_clause_2_set.add(e[1]) + + # Update the edge groundings (to use later for grounding other clauses with the same variables) + groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings + + # Update dependency graph + # Add a connection between clause_var_1 -> clause_var_2 and vice versa + if clause_var_1 not in dependency_graph_neighbors: + dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2]) + elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]: + dependency_graph_neighbors[clause_var_1].append(clause_var_2) + if clause_var_2 not in dependency_graph_reverse_neighbors: + dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1]) + elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]: + dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1) + + # This is a comparison clause + else: + pass + + # Refine the subsets based on any updates + if satisfaction: + refine_groundings(clause_variables, groundings, groundings_edges, 
dependency_graph_neighbors, dependency_graph_reverse_neighbors) + + # If satisfaction is false, break + if not satisfaction: + break + + # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules + # Then continue to setup any edges to be added and annotations + # Fill out the rules to be applied lists + if satisfaction: + # Create temp grounding containers to verify if the head groundings are valid (only for edge rules) + # Setup edges to be added and fill rules to be applied + # Setup traces and inputs for annotation function + # Loop through the clause data and setup final annotations and trace variables + # Three cases: 1.node rule, 2. edge rule with infer edges, 3. edge rule + if rule_type == 'node': + # Loop through all the head variable groundings and add it to the rules to be applied + # Loop through the clauses and add appropriate trace data and annotations + + # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph + head_var_1_in_nodes = head_var_1 in nodes + add_head_var_node_to_graph = False + if allow_ground_rules and head_var_1_in_nodes: + groundings[head_var_1] = numba.typed.List([head_var_1]) + elif head_var_1 not in groundings: + if not head_var_1_in_nodes: + add_head_var_node_to_graph = True + groundings[head_var_1] = numba.typed.List([head_var_1]) + + for head_grounding in groundings[head_var_1]: + qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) + edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + + # Check for satisfaction one more time in case the refining process has changed the groundings + satisfaction = check_all_clause_satisfaction(interpretations_node, 
interpretations_edge, clauses, thresholds, groundings, groundings_edges) + if not satisfaction: + continue + + for i, clause in enumerate(clauses): + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + + if clause_type == 'node': + clause_var_1 = clause_variables[0] + + # 1. + if atom_trace: + if clause_var_1 == head_var_1: + qualified_nodes.append(numba.typed.List([head_grounding])) + else: + qualified_nodes.append(numba.typed.List(groundings[clause_var_1])) + qualified_edges.append(numba.typed.List.empty_list(edge_type)) + # 2. + if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1: + a.append(interpretations_node[head_grounding].world[clause_label]) + else: + for qn in groundings[clause_var_1]: + a.append(interpretations_node[qn].world[clause_label]) + annotations.append(a) + + elif clause_type == 'edge': + clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] + # 1. + if atom_trace: + # Cases: Both equal, one equal, none equal + qualified_nodes.append(numba.typed.List.empty_list(node_type)) + if clause_var_1 == head_var_1: + es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding]) + qualified_edges.append(es) + elif clause_var_2 == head_var_1: + es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding]) + qualified_edges.append(es) + else: + qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)])) + # 2. 
+ if ann_fn != '': + a = numba.typed.List.empty_list(interval.interval_type) + if clause_var_1 == head_var_1: + for e in groundings_edges[(clause_var_1, clause_var_2)]: + if e[0] == head_grounding: + a.append(interpretations_edge[e].world[clause_label]) + elif clause_var_2 == head_var_1: + for e in groundings_edges[(clause_var_1, clause_var_2)]: + if e[1] == head_grounding: + a.append(interpretations_edge[e].world[clause_label]) + else: + for qe in groundings_edges[(clause_var_1, clause_var_2)]: + a.append(interpretations_edge[qe].world[clause_label]) + annotations.append(a) + else: + # Comparison clause (we do not handle for now) + pass + + # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules) + if add_head_var_node_to_graph: + _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node) + + # For each grounding add a rule to be applied + applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) + + elif rule_type == 'edge': + head_var_1 = head_variables[0] + head_var_2 = head_variables[1] + + # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph + head_var_1_in_nodes = head_var_1 in nodes + head_var_2_in_nodes = head_var_2 in nodes + add_head_var_1_node_to_graph = False + add_head_var_2_node_to_graph = False + add_head_edge_to_graph = False + if allow_ground_rules and head_var_1_in_nodes: + groundings[head_var_1] = numba.typed.List([head_var_1]) + if allow_ground_rules and head_var_2_in_nodes: + groundings[head_var_2] = numba.typed.List([head_var_2]) + + if head_var_1 not in groundings: + if not head_var_1_in_nodes: + add_head_var_1_node_to_graph = True + groundings[head_var_1] = numba.typed.List([head_var_1]) + if head_var_2 not in groundings: + if not head_var_2_in_nodes: + add_head_var_2_node_to_graph = True + groundings[head_var_2] = numba.typed.List([head_var_2]) + + # 
Artificially connect the head variables with an edge if both of them were not in the graph + if not head_var_1_in_nodes and not head_var_2_in_nodes: + add_head_edge_to_graph = True + + head_var_1_groundings = groundings[head_var_1] + head_var_2_groundings = groundings[head_var_2] + + source, target, _ = rule_edges + infer_edges = True if source != '' and target != '' else False + + # Prepare the edges that we will loop over. + # For infer edges we loop over each combination pair + # Else we loop over the valid edges in the graph + valid_edge_groundings = numba.typed.List.empty_list(edge_type) + for g1 in head_var_1_groundings: + for g2 in head_var_2_groundings: + if infer_edges: + valid_edge_groundings.append((g1, g2)) + else: + if (g1, g2) in edges_set: + valid_edge_groundings.append((g1, g2)) + + # Loop through the head variable groundings + for valid_e in valid_edge_groundings: + head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1] + qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) + edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) + + # Containers to keep track of groundings to make sure that the edge pair is valid + # We do this because we cannot know beforehand the edge matches from source groundings to target groundings + temp_groundings = groundings.copy() + temp_groundings_edges = groundings_edges.copy() + + # Refine the temp groundings for the specific edge head grounding + # We update the edge collection as well depending on if there's a match between the clause variables and head variables + temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding]) + temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding]) + for c1, c2 in 
temp_groundings_edges.keys(): + if c1 == head_var_1 and c2 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)]) + elif c1 == head_var_2 and c2 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)]) + elif c1 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding]) + elif c2 == head_var_1: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding]) + elif c1 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding]) + elif c2 == head_var_2: + temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding]) + + refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors) + + # Check if the thresholds are still satisfied + # Check if all clauses are satisfied again in case the refining process changed anything + satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges) + + if not satisfaction: + continue + + if infer_edges: + # Prevent self loops while inferring edges if the clause variables are not the same + if source != target and head_var_1_grounding == head_var_2_grounding: + continue + edges_to_be_added[0].append(head_var_1_grounding) + edges_to_be_added[1].append(head_var_2_grounding) + + for i, clause in enumerate(clauses): + clause_type = clause[0] + clause_label = clause[1] + clause_variables = clause[2] + + if clause_type == 'node': + clause_var_1 = 
clause_variables[0]
+                    # 1.
+                    if atom_trace:
+                        if clause_var_1 == head_var_1:
+                            qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
+                        elif clause_var_1 == head_var_2:
+                            qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
+                        else:
+                            qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
+                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        if clause_var_1 == head_var_1:
+                            a.append(interpretations_node[head_var_1_grounding].world[clause_label])
+                        elif clause_var_1 == head_var_2:
+                            a.append(interpretations_node[head_var_2_grounding].world[clause_label])
+                        else:
+                            for qn in temp_groundings[clause_var_1]:
+                                a.append(interpretations_node[qn].world[clause_label])
+                        annotations.append(a)
+
+                elif clause_type == 'edge':
+                    clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+                    # 1.
+                    if atom_trace:
+                        # Cases:
+                        # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
+                        # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
+                        # 3. None equal
+                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                        if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
+                            qualified_edges.append(es)
+                        elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
+                            qualified_edges.append(es)
+                        elif clause_var_1 == head_var_1:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
+                            qualified_edges.append(es)
+                        elif clause_var_1 == head_var_2:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
+                            qualified_edges.append(es)
+                        elif clause_var_2 == head_var_1:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
+                            qualified_edges.append(es)
+                        elif clause_var_2 == head_var_2:
+                            es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
+                            qualified_edges.append(es)
+                        else:
+                            qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
+
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        elif clause_var_1 == head_var_1:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[0] == head_var_1_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        elif clause_var_1 == head_var_2:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[0] == head_var_2_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        elif clause_var_2 == head_var_1:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[1] == head_var_1_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        elif clause_var_2 == head_var_2:
+                            for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                if e[1] == head_var_2_grounding:
+                                    a.append(interpretations_edge[e].world[clause_label])
+                        else:
+                            for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                a.append(interpretations_edge[qe].world[clause_label])
+                        annotations.append(a)
+
+            # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+            if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
+                _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+            if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
+                _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
+            if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
+                _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge)
+
+            # For each grounding combination add a rule to be applied
+            # Only if all the clauses have valid groundings
+            # if satisfaction:
+            e = (head_var_1_grounding, head_var_2_grounding)
+            applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+    # Return the applicable rules
+    return applicable_rules_node, applicable_rules_edge
+
+
+@numba.njit(cache=True)
+def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
+    # Check if the thresholds are satisfied for each clause
+    satisfaction = True
+    for i, clause in enumerate(clauses):
+        # Unpack clause variables
+        clause_type = clause[0]
+        clause_label = clause[1]
+        clause_variables = clause[2]
+
+        if clause_type == 'node':
+            clause_var_1 = clause_variables[0]
+            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
+        elif clause_type == 'edge':
+            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction
+    return satisfaction
+
+
+@numba.njit(cache=True)
 def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip):
     # Extract rule params
     rule_type = rule.get_type()
@@ -757,7 +1214,7 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
     # We return a list of tuples which specify the target nodes/edges that have made the rule body true
     applicable_rules = numba.typed.List.empty_list(node_applicable_rule_type)
-
+
     # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
     # One array for each node, then condense into a single list later
    applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(node_applicable_rule_type) for _ in nodes])
@@ -785,6 +1242,7 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
         qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
         annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
         edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+        clause_type_and_variables = numba.typed.List.empty_list(clause_data)
         satisfaction = True
         for i, clause in enumerate(clauses):
@@ -795,28 +1253,16 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
             clause_bnd = clause[3]
             clause_operator = clause[4]
-            # Unpack thresholds
-            # This value is total/available
-            threshold_quantifier_type = thresholds[i][1][1]
-
             # This is a node clause
             # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
             if clause_type == 'node':
                 clause_var_1 = clause_variables[0]
-                subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
+                subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes)
                 subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
-                if atom_trace:
-                    qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
-                    qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
-                # Add annotations if necessary
-                if ann_fn != '':
-                    a = numba.typed.List.empty_list(interval.interval_type)
-                    for qn in subsets[clause_var_1]:
-                        a.append(interpretations_node[qn].world[clause_label])
-                    annotations.append(a)
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1])))
 
             # This is an edge clause
             elif clause_type == 'edge':
@@ -828,16 +1274,9 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
                 subsets[clause_var_1] = qe[0]
                 subsets[clause_var_2] = qe[1]
-                if atom_trace:
-                    qualified_nodes.append(numba.typed.List.empty_list(node_type))
-                    qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
-                # Add annotations if necessary
-                if ann_fn != '':
-                    a = numba.typed.List.empty_list(interval.interval_type)
-                    for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
-                        a.append(interpretations_edge[qe].world[clause_label])
-                    annotations.append(a)
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
+
             else:
                 # This is a comparison clause
                 # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
@@ -852,8 +1291,8 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
                 # It's a node comparison
                 if len(clause_variables) == 2:
                     clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
-                    subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
-                    subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, neighbors)
+                    subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes)
+                    subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, nodes)
 
                     # 1, 2
                     qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
@@ -881,19 +1320,10 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
                     # Update subsets with final qualified nodes
                     subsets[clause_var_1] = qualified_nodes_1
                     subsets[clause_var_2] = qualified_nodes_2
-                    qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
-                    qualified_comparison_nodes.extend(qualified_nodes_2)
-                    if atom_trace:
-                        qualified_nodes.append(qualified_comparison_nodes)
-                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
 
-                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
-                    if ann_fn != '':
-                        a = numba.typed.List.empty_list(interval.interval_type)
-                        for qn in qualified_comparison_nodes:
-                            a.append(interval.closed(1, 1))
-                        annotations.append(a)
                 # Edge comparison. Compare stage
                 else:
                     satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
@@ -907,39 +1337,19 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
                     subsets[clause_var_2_source] = qualified_nodes_2_source
                     subsets[clause_var_2_target] = qualified_nodes_2_target
-                    qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
-                    qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
-                    qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
-                    qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
-
-                    if atom_trace:
-                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
-                        qualified_edges.append(qualified_comparison_nodes)
-
-                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
-                    if ann_fn != '':
-                        a = numba.typed.List.empty_list(interval.interval_type)
-                        for qe in qualified_comparison_nodes:
-                            a.append(interval.closed(1, 1))
-                        annotations.append(a)
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target])))
 
             # Non comparison clause
             else:
-                if threshold_quantifier_type == 'total':
-                    if clause_type == 'node':
-                        neigh_len = len(subset)
-                    else:
-                        neigh_len = sum([len(l) for l in subset_target])
-
-                # Available is all neighbors that have a particular label with bound inside [0,1]
-                elif threshold_quantifier_type == 'available':
-                    if clause_type == 'node':
-                        neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0,1)))
-                    else:
-                        neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0,1), reverse_graph)[0])
-
-                qualified_neigh_len = len(subsets[clause_var_1])
-                satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
+                if clause_type == 'node':
+                    satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction
+                else:
+                    satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction
+
+            # Refine subsets based on any updates
+            if satisfaction:
+                satisfaction = refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction
 
             # Exit loop if even one clause is not satisfied
             if not satisfaction:
                 break
@@ -966,9 +1376,83 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
                 else:
                     edges_to_be_added[1].append(target)
 
+            # Loop through the clause data and setup final annotations and trace variables
+            # 1. Add qualified nodes/edges to trace
+            # 2. Add annotations to annotation function variable
+            for i, clause in enumerate(clause_type_and_variables):
+                clause_type = clause[0]
+                clause_label = clause[1]
+                clause_variables = clause[2]
+
+                if clause_type == 'node':
+                    clause_var_1 = clause_variables[0]
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
+                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qn in subsets[clause_var_1]:
+                            a.append(interpretations_node[qn].world[clause_label])
+                        annotations.append(a)
+
+                elif clause_type == 'edge':
+                    clause_var_1, clause_var_2 = clause_variables
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                        qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
+                            a.append(interpretations_edge[qe].world[clause_label])
+                        annotations.append(a)
+
+                elif clause_type == 'node-comparison':
+                    clause_var_1, clause_var_2 = clause_variables
+                    qualified_nodes_1 = subsets[clause_var_1]
+                    qualified_nodes_2 = subsets[clause_var_2]
+                    qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
+                    qualified_comparison_nodes.extend(qualified_nodes_2)
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(qualified_comparison_nodes)
+                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                    # 2.
+                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qn in qualified_comparison_nodes:
+                            a.append(interval.closed(1, 1))
+                        annotations.append(a)
+
+                elif clause_type == 'edge-comparison':
+                    clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables
+                    qualified_nodes_1_source = subsets[clause_var_1_source]
+                    qualified_nodes_1_target = subsets[clause_var_1_target]
+                    qualified_nodes_2_source = subsets[clause_var_2_source]
+                    qualified_nodes_2_target = subsets[clause_var_2_target]
+                    qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
+                    qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
+                    qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
+                    qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                        qualified_edges.append(qualified_comparison_nodes)
+                    # 2.
+                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qe in qualified_comparison_nodes:
+                            a.append(interval.closed(1, 1))
+                        annotations.append(a)
+
             # node/edge, annotations, qualified nodes, qualified edges, edges to be added
             applicable_rules_threadsafe[piter] = numba.typed.List([(target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])
-
+
     # Merge all threadsafe rules into one single array
     for applicable_rule in applicable_rules_threadsafe:
         if len(applicable_rule) > 0:
@@ -977,7 +1461,7 @@ def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, n
     return applicable_rules
 
-@numba.njit(cache=False, parallel=True)
+@numba.njit(cache=True)
 def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip):
     # Extract rule params
     rule_type = rule.get_type()
@@ -988,7 +1472,7 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
     # We return a list of tuples which specify the target nodes/edges that have made the rule body true
     applicable_rules = numba.typed.List.empty_list(edge_applicable_rule_type)
-
+
     # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
     # One array for each node, then condense into a single list later
     applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(edge_applicable_rule_type) for _ in edges])
@@ -1016,6 +1500,7 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
         qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
         annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
         edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+        clause_type_and_variables = numba.typed.List.empty_list(clause_data)
         satisfaction = True
         for i, clause in enumerate(clauses):
@@ -1026,27 +1511,16 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
             clause_bnd = clause[3]
             clause_operator = clause[4]
-            # Unpack thresholds
-            # This value is total/available
-            threshold_quantifier_type = thresholds[i][1][1]
-
            # This is a node clause
            # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
            if clause_type == 'node':
                 clause_var_1 = clause_variables[0]
-                subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
-
+                subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes)
+
                 subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
-                if atom_trace:
-                    qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
-                    qualified_edges.append(numba.typed.List.empty_list(edge_type))
-                # Add annotations if necessary
-                if ann_fn != '':
-                    a = numba.typed.List.empty_list(interval.interval_type)
-                    for qn in subsets[clause_var_1]:
-                        a.append(interpretations_node[qn].world[clause_label])
-                    annotations.append(a)
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1])))
 
            # This is an edge clause
            elif clause_type == 'edge':
@@ -1058,17 +1532,9 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
                 subsets[clause_var_1] = qe[0]
                 subsets[clause_var_2] = qe[1]
-                if atom_trace:
-                    qualified_nodes.append(numba.typed.List.empty_list(node_type))
-                    qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
-                # Add annotations if necessary
-                if ann_fn != '':
-                    a = numba.typed.List.empty_list(interval.interval_type)
-                    for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
-                        a.append(interpretations_edge[qe].world[clause_label])
-                    annotations.append(a)
-
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
+
            else:
                 # This is a comparison clause
                 # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
@@ -1079,17 +1545,17 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
                 # 2. get qualified nodes/edges as well as number associated for second predicate
                 # 3. if there's no number in steps 1 or 2 return false clause
                 # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
-
+
                 # It's a node comparison
                 if len(clause_variables) == 2:
                     clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
-                    subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
-                    subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, neighbors)
-
+                    subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes)
+                    subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, nodes)
+
                     # 1, 2
                     qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
                     qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
-
+
                 # It's an edge comparison
                 elif len(clause_variables) == 4:
                     clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
@@ -1112,19 +1578,10 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
                     # Update subsets with final qualified nodes
                     subsets[clause_var_1] = qualified_nodes_1
                     subsets[clause_var_2] = qualified_nodes_2
-
-                    qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
-                    qualified_comparison_nodes.extend(qualified_nodes_2)
-                    if atom_trace:
-                        qualified_nodes.append(qualified_comparison_nodes)
-                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
 
-                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
-                    if ann_fn != '':
-                        a = numba.typed.List.empty_list(interval.interval_type)
-                        for qn in qualified_comparison_nodes:
-                            a.append(interval.closed(1, 1))
-                        annotations.append(a)
                 # Edge comparison. Compare stage
                 else:
                     satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
@@ -1138,40 +1595,20 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
                     subsets[clause_var_2_source] = qualified_nodes_2_source
                     subsets[clause_var_2_target] = qualified_nodes_2_target
-                    qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
-                    qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
-                    qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
-                    qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
-
-                    if atom_trace:
-                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
-                        qualified_edges.append(qualified_comparison_nodes)
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target])))
 
-                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
-                    if ann_fn != '':
-                        a = numba.typed.List.empty_list(interval.interval_type)
-                        for qe in qualified_comparison_nodes:
-                            a.append(interval.closed(1, 1))
-                        annotations.append(a)
-
            # Non comparison clause
            else:
-                if threshold_quantifier_type == 'total':
-                    if clause_type == 'node':
-                        neigh_len = len(subset)
-                    else:
-                        neigh_len = sum([len(l) for l in subset_target])
-
-                # Available is all neighbors that have a particular label with bound inside [0,1]
-                elif threshold_quantifier_type == 'available':
-                    if clause_type == 'node':
-                        neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1)))
-                    else:
-                        neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0])
-
-                qualified_neigh_len = len(subsets[clause_var_1])
-                satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
-
+                if clause_type == 'node':
+                    satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction
+                else:
+                    satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction
+
+            # Refine subsets based on any updates
+            if satisfaction:
+                satisfaction = refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction
+
            # Exit loop if even one clause is not satisfied
            if not satisfaction:
                break
@@ -1179,30 +1616,79 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
     # Here we are done going through each clause of the rule
     # If all clauses were satisfied, proceed to collect annotations and prepare edges to be added
     if satisfaction:
-        # Collect edges to be added
-        source, target, _ = rule_edges
+        # Loop through the clause data and setup final annotations and trace variables
+        # 1. Add qualified nodes/edges to trace
+        # 2. Add annotations to annotation function variable
+        for i, clause in enumerate(clause_type_and_variables):
+            clause_type = clause[0]
+            clause_label = clause[1]
+            clause_variables = clause[2]
+
+            if clause_type == 'node':
+                clause_var_1 = clause_variables[0]
+                # 1.
+                if atom_trace:
+                    qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
+                    qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                # 2.
+                if ann_fn != '':
+                    a = numba.typed.List.empty_list(interval.interval_type)
+                    for qn in subsets[clause_var_1]:
+                        a.append(interpretations_node[qn].world[clause_label])
+                    annotations.append(a)
 
-        # Edges to be added
-        if source != '' and target != '':
-            # Check if edge nodes are source/target
-            if source == '__source':
-                edges_to_be_added[0].append(target_edge[0])
-            elif source == '__target':
-                edges_to_be_added[0].append(target_edge[1])
-            elif source in subsets:
-                edges_to_be_added[0].extend(subsets[source])
-            else:
-                edges_to_be_added[0].append(source)
+            elif clause_type == 'edge':
+                clause_var_1, clause_var_2 = clause_variables
+                # 1.
+                if atom_trace:
+                    qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                    qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
+                # 2.
+                if ann_fn != '':
+                    a = numba.typed.List.empty_list(interval.interval_type)
+                    for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
+                        a.append(interpretations_edge[qe].world[clause_label])
+                    annotations.append(a)
 
-            if target == '__source':
-                edges_to_be_added[1].append(target_edge[0])
-            elif target == '__target':
-                edges_to_be_added[1].append(target_edge[1])
-            elif target in subsets:
-                edges_to_be_added[1].extend(subsets[target])
-            else:
-                edges_to_be_added[1].append(target)
+            elif clause_type == 'node-comparison':
+                clause_var_1, clause_var_2 = clause_variables
+                qualified_nodes_1 = subsets[clause_var_1]
+                qualified_nodes_2 = subsets[clause_var_2]
+                qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
+                qualified_comparison_nodes.extend(qualified_nodes_2)
+                # 1.
+                if atom_trace:
+                    qualified_nodes.append(qualified_comparison_nodes)
+                    qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                # 2.
+                # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                if ann_fn != '':
+                    a = numba.typed.List.empty_list(interval.interval_type)
+                    for qn in qualified_comparison_nodes:
+                        a.append(interval.closed(1, 1))
+                    annotations.append(a)
+
+            elif clause_type == 'edge-comparison':
+                clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables
+                qualified_nodes_1_source = subsets[clause_var_1_source]
+                qualified_nodes_1_target = subsets[clause_var_1_target]
+                qualified_nodes_2_source = subsets[clause_var_2_source]
+                qualified_nodes_2_target = subsets[clause_var_2_target]
+                qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
+                qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
+                qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
+                qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
+                # 1.
+                if atom_trace:
+                    qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                    qualified_edges.append(qualified_comparison_nodes)
+                # 2.
+                # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                if ann_fn != '':
+                    a = numba.typed.List.empty_list(interval.interval_type)
+                    for qe in qualified_comparison_nodes:
+                        a.append(interval.closed(1, 1))
                    annotations.append(a)
 
         # node/edge, annotations, qualified nodes, qualified edges, edges to be added
         applicable_rules_threadsafe[piter] = numba.typed.List([(target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])
@@ -1214,18 +1700,359 @@ def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, e
     return applicable_rules
 
-@numba.njit(cache=False)
-def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors):
+@numba.njit(cache=True)
+def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
+    # Loop through the dependency graph and refine the groundings that have connections
+    all_variables_refined = numba.typed.List(clause_variables)
+    variables_just_refined = numba.typed.List(clause_variables)
+    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
+    while len(variables_just_refined) > 0:
+        for refined_variable in variables_just_refined:
+            # Refine all the neighbors of the refined variable
+            if refined_variable in dependency_graph_neighbors:
+                for neighbor in dependency_graph_neighbors[refined_variable]:
+                    old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
+                    new_node_groundings = groundings[refined_variable]
+
+                    # Delete old groundings for the variable being refined
+                    del groundings[neighbor]
+                    groundings[neighbor] = numba.typed.List.empty_list(node_type)
+
+                    # Update the edge groundings and node groundings
+                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
+                    groundings_neighbor_set = set(groundings[neighbor])
+                    for e in qualified_groundings:
+                        if e[1] not in groundings_neighbor_set:
+                            groundings[neighbor].append(e[1])
+                            groundings_neighbor_set.add(e[1])
+                    groundings_edges[(refined_variable, neighbor)] = qualified_groundings
+
+                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+                    if neighbor not in all_variables_refined:
+                        new_variables_refined.append(neighbor)
+
+            if refined_variable in dependency_graph_reverse_neighbors:
+                for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
+                    old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
+                    new_node_groundings = groundings[refined_variable]
+
+                    # Delete old groundings for the variable being refined
+                    del groundings[reverse_neighbor]
+                    groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)
+
+                    # Update the edge groundings and node groundings
+                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
+                    groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
+                    for e in qualified_groundings:
+                        if e[0] not in groundings_reverse_neighbor_set:
+                            groundings[reverse_neighbor].append(e[0])
+                            groundings_reverse_neighbor_set.add(e[0])
+                    groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings
+
+                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+                    if reverse_neighbor not in all_variables_refined:
+                        new_variables_refined.append(reverse_neighbor)
+
+        variables_just_refined = numba.typed.List(new_variables_refined)
+        all_variables_refined.extend(new_variables_refined)
+        new_variables_refined.clear()
+
+
+@numba.njit(cache=True)
+def refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph):
+    """NOTE: DEPRECATED"""
+    # Loop through all clauses till clause i-1 and update subsets recursively
+    # Then check if the clause still satisfies the thresholds
+    clause = clauses[i]
+    clause_type = clause[0]
+    clause_label = clause[1]
+    clause_variables = clause[2]
+    clause_bnd = clause[3]
+    clause_operator = clause[4]
+
+    # Keep track of the variables that were refined (start with clause_variables) and variables that need refining
+    satisfaction = True
+    all_variables_refined = numba.typed.List(clause_variables)
+    variables_just_refined = numba.typed.List(clause_variables)
+    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
+    while len(variables_just_refined) > 0:
+        for j in range(i):
+            c = clauses[j]
+            c_type = c[0]
+            c_label = c[1]
+            c_variables = c[2]
+            c_bnd = c[3]
+            c_operator = c[4]
+
+            # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables
+            # If yes, then update the variable that is with it in the clause
+            if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2):
+                for v in variables_just_refined:
+                    for k, cv in enumerate(c_variables):
+                        if cv == v:
+                            # Find which variable needs to be refined, 1st or 2nd.
+                            # 2nd variable needs refining
+                            if k == 0:
+                                refine_idx = 1
+                                refine_v = c_variables[1]
+                            # 1st variable needs refining
+                            elif k == 1:
+                                refine_idx = 0
+                                refine_v = c_variables[0]
+                            # 2nd variable needs refining
+                            elif k == 2:
+                                refine_idx = 1
+                                refine_v = c_variables[3]
+                            # 1st variable needs refining
+                            else:
+                                refine_idx = 0
+                                refine_v = c_variables[2]
+
+                            # Refine the variable
+                            if refine_v not in all_variables_refined:
+                                new_variables_refined.append(refine_v)
+
+                            if c_type == 'edge':
+                                clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v)
+                                del subsets[refine_v]
+                                subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes)
+
+                                # Get qualified edges
+                                qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph)
+                                subsets[clause_var_1] = qe[0]
+                                subsets[clause_var_2] = qe[1]
+
+                                # Check if we still satisfy the clause
+                                satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction
+                            else:
+                                # We do not support refinement for comparison clauses
+                                pass
+
+            if not satisfaction:
+                return satisfaction
+
+        variables_just_refined = numba.typed.List(new_variables_refined)
+        all_variables_refined.extend(new_variables_refined)
+        new_variables_refined.clear()
+
+    return satisfaction
+
+
+@numba.njit(cache=True)
+def refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph):
+    """NOTE: DEPRECATED"""
+    # Loop through all clauses till clause i-1 and update subsets recursively
+    # Then check if the clause still satisfies the thresholds
+    clause = clauses[i]
+    clause_type = clause[0]
+    clause_label = clause[1]
+    clause_variables = clause[2]
+    clause_bnd = clause[3]
+    clause_operator = clause[4]
+
+
# Keep track of the variables that were refined (start with clause_variables) and variables that need refining + satisfaction = True + all_variables_refined = numba.typed.List(clause_variables) + variables_just_refined = numba.typed.List(clause_variables) + new_variables_refined = numba.typed.List.empty_list(numba.types.string) + while len(variables_just_refined) > 0: + for j in range(i): + c = clauses[j] + c_type = c[0] + c_label = c[1] + c_variables = c[2] + c_bnd = c[3] + c_operator = c[4] + + # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables + # If yes, then update the variable that is with it in the clause + if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2): + for v in variables_just_refined: + for k, cv in enumerate(c_variables): + if cv == v: + # Find which variable needs to be refined, 1st or 2nd. + # 2nd variable needs refining + if k == 0: + refine_idx = 1 + refine_v = c_variables[1] + # 1st variable needs refining + elif k == 1: + refine_idx = 0 + refine_v = c_variables[0] + # 2nd variable needs refining + elif k == 2: + refine_idx = 1 + refine_v = c_variables[3] + # 1st variable needs refining + else: + refine_idx = 0 + refine_v = c_variables[2] + + # Refine the variable + if refine_v not in all_variables_refined: + new_variables_refined.append(refine_v) + + if c_type == 'edge': + clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v) + del subsets[refine_v] + subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes) + + # Get qualified edges + qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph) + subsets[clause_var_1] = qe[0] + subsets[clause_var_2] = qe[1] + + # Check if we still satisfy the clause + satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, 
subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction + else: + # We do not support refinement for comparison clauses + pass + + if not satisfaction: + return satisfaction + + variables_just_refined = numba.typed.List(new_variables_refined) + all_variables_refined.extend(new_variables_refined) + new_variables_refined.clear() + + return satisfaction + + +@numba.njit(cache=True) +def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold): + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = len(grounding) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1))) + + qualified_neigh_len = len(qualified_grounding) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold): + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = len(grounding) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1))) + + qualified_neigh_len = len(qualified_grounding) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, threshold): + """NOTE: DEPRECATED""" + threshold_quantifier_type = threshold[1][1] + if 
threshold_quantifier_type == 'total': + neigh_len = len(subset) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1))) + + # Only take length of clause_var_1 because length of subsets of var_1 and var_2 are supposed to be equal + qualified_neigh_len = len(subsets[clause_var_1]) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, threshold, reverse_graph): + """NOTE: DEPRECATED""" + threshold_quantifier_type = threshold[1][1] + if threshold_quantifier_type == 'total': + neigh_len = sum([len(l) for l in subset_target]) + + # Available is all neighbors that have a particular label with bound inside [0,1] + elif threshold_quantifier_type == 'available': + neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0]) + + qualified_neigh_len = len(subsets[clause_var_1]) + satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) + return satisfaction + + +@numba.njit(cache=True) +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): + # The groundings for a node clause can be either a previous grounding or all possible nodes + if l in predicate_map: + grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] + else: + grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] + return grounding + + +@numba.njit(cache=True) +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, 
edges): + # There are 4 cases for predicate(Y,Z): + # 1. Both predicate variables Y and Z have not been encountered before + # 2. The source variable Y has not been encountered before but the target variable Z has + # 3. The target variable Z has not been encountered before but the source variable Y has + # 4. Both predicate variables Y and Z have been encountered before + edge_groundings = numba.typed.List.empty_list(edge_type) + + # Case 1: + # We take all edges with the predicate from the predicate map if available, otherwise all edges in the graph + if clause_var_1 not in groundings and clause_var_2 not in groundings: + if l in predicate_map: + edge_groundings = predicate_map[l] + else: + edge_groundings = edges + + # Case 2: + # We replace Y by the sources of Z + elif clause_var_1 not in groundings and clause_var_2 in groundings: + for n in groundings[clause_var_2]: + es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]]) + edge_groundings.extend(es) + + # Case 3: + # We replace Z by the neighbors of Y + elif clause_var_1 in groundings and clause_var_2 not in groundings: + for n in groundings[clause_var_1]: + es = numba.typed.List([(n, nn) for nn in neighbors[n]]) + edge_groundings.extend(es) + + # Case 4: + # We have seen both variables before + else: + # We have already seen these two variables in an edge clause + if (clause_var_1, clause_var_2) in groundings_edges: + edge_groundings = groundings_edges[(clause_var_1, clause_var_2)] + # We have seen both these variables but not in an edge clause together + else: + groundings_clause_var_2_set = set(groundings[clause_var_2]) + for n in groundings[clause_var_1]: + es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set]) + edge_groundings.extend(es) + + return edge_groundings + + +@numba.njit(cache=True) +def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes): + """NOTE: DEPRECATED""" # The groundings for node clauses are either the target node, neighbors of the target node, or
an existing subset of nodes if clause_var_1 == '__target': subset = numba.typed.List([target_node]) else: - subset = neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1] + nodes_without_target = numba.typed.List([n for n in nodes if n != target_node]) + subset = nodes_without_target if clause_var_1 not in subsets else subsets[clause_var_1] return subset -@numba.njit(cache=False) +@numba.njit(cache=True) def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes): + """NOTE: DEPRECATED""" # There are 5 cases for predicate(Y,Z): # 1. Either one or both of Y, Z are the target node # 2. Both predicate variables Y and Z have not been encountered before @@ -1252,10 +2079,10 @@ def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, su subset_target = numba.typed.List([numba.typed.List([target_node]) for _ in subset_source]) # Case 2: - # We replace Y by all nodes and Z by the neighbors of each of these nodes + # We replace Y by all nodes (except target_node) and Z by the neighbors of each of these nodes elif clause_var_1 not in subsets and clause_var_2 not in subsets: - subset_source = numba.typed.List(nodes) - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_source = numba.typed.List([n for n in nodes if n != target_node]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_node]) for n in subset_source]) # Case 3: # We replace Y by the sources of Z @@ -1266,37 +2093,50 @@ def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, su for n in subsets[clause_var_2]: sources = reverse_neighbors[n] for source in sources: - subset_source.append(source) - subset_target.append(numba.typed.List([n])) + if source != target_node: + subset_source.append(source) + subset_target.append(numba.typed.List([n])) # Case 4: # We replace Z by the neighbors of Y elif clause_var_1 in 
subsets and clause_var_2 not in subsets: subset_source = subsets[clause_var_1] - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_node]) for n in subset_source]) # Case 5: else: subset_source = subsets[clause_var_1] subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source]) + # If any of the subsets are empty return them in the correct type + if len(subset_source) == 0: + subset_source = numba.typed.List.empty_list(node_type) + subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + # If any sub lists in subset target are empty, add correct type for empty list + for i, t in enumerate(subset_target): + if len(t) == 0: + subset_target[i] = numba.typed.List.empty_list(node_type) + return subset_source, subset_target -@numba.njit(cache=False) -def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors): +@numba.njit(cache=True) +def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes): + """NOTE: DEPRECATED""" # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes if clause_var_1 == '__source': subset = numba.typed.List([target_edge[0]]) elif clause_var_1 == '__target': subset = numba.typed.List([target_edge[1]]) else: - subset = neighbors[target_edge[0]] if clause_var_1 not in subsets else subsets[clause_var_1] + nodes_without_target_or_source = numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]]) + subset = nodes_without_target_or_source if clause_var_1 not in subsets else subsets[clause_var_1] return subset -@numba.njit(cache=False) +@numba.njit(cache=True) def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes): + """NOTE: DEPRECATED""" # There are 5 cases for predicate(Y,Z): # 1. 
Either one or both of Y, Z are the source or target node # 2. Both predicate variables Y and Z have not been encountered before @@ -1342,10 +2182,10 @@ def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, su subset_target = numba.typed.List([numba.typed.List([target_edge[1]]) for _ in subset_source]) # Case 2: - # We replace Y by all nodes and Z by the neighbors of each of these nodes + # We replace Y by all nodes (except source/target) and Z by the neighbors of each of these nodes elif clause_var_1 not in subsets and clause_var_2 not in subsets: - subset_source = numba.typed.List(nodes) - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_source = numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source]) # Case 3: # We replace Y by the sources of Z @@ -1356,36 +2196,70 @@ def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, su for n in subsets[clause_var_2]: sources = reverse_neighbors[n] for source in sources: - subset_source.append(source) - subset_target.append(numba.typed.List([n])) + if source != target_edge[0] and source != target_edge[1]: + subset_source.append(source) + subset_target.append(numba.typed.List([n])) # Case 4: # We replace Z by the neighbors of Y elif clause_var_1 in subsets and clause_var_2 not in subsets: subset_source = subsets[clause_var_1] - subset_target = numba.typed.List([neighbors[n] for n in subset_source]) + subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source]) # Case 5: else: subset_source = subsets[clause_var_1] subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source]) + # If any of the subsets are empty return them in the correct type + if 
len(subset_source) == 0: + subset_source = numba.typed.List.empty_list(node_type) + subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + # If any sub lists in subset target are empty, add correct type for empty list + for i, t in enumerate(subset_target): + if len(t) == 0: + subset_target[i] = numba.typed.List.empty_list(node_type) + return subset_source, subset_target -@numba.njit(cache=False) +@numba.njit(cache=True) +def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd): + # Filter the grounding by the predicate and bound of the clause + qualified_groundings = numba.typed.List.empty_list(node_type) + for n in grounding: + if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)): + qualified_groundings.append(n) + + return qualified_groundings + + +@numba.njit(cache=True) +def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd): + # Filter the grounding by the predicate and bound of the clause + qualified_groundings = numba.typed.List.empty_list(edge_type) + for e in grounding: + if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)): + qualified_groundings.append(e) + + return qualified_groundings + + +@numba.njit(cache=True) def get_qualified_components_node_clause(interpretations_node, candidates, l, bnd): + """NOTE: DEPRECATED""" # Get all the qualified neighbors for a particular clause qualified_nodes = numba.typed.List.empty_list(node_type) for n in candidates: - if is_satisfied_node(interpretations_node, n, (l, bnd)): + if is_satisfied_node(interpretations_node, n, (l, bnd)) and n not in qualified_nodes: qualified_nodes.append(n) return qualified_nodes -@numba.njit(cache=False) +@numba.njit(cache=True) def get_qualified_components_node_comparison_clause(interpretations_node, candidates, l, bnd): + """NOTE: DEPRECATED""" # Get all the qualified neighbors for a particular comparison clause and return them along with the number 
associated qualified_nodes = numba.typed.List.empty_list(node_type) qualified_nodes_numbers = numba.typed.List.empty_list(numba.types.float64) @@ -1398,8 +2272,9 @@ def get_qualified_components_node_comparison_clause(interpretations_node, candid return qualified_nodes, qualified_nodes_numbers -@numba.njit(cache=False) +@numba.njit(cache=True) def get_qualified_components_edge_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph): + """NOTE: DEPRECATED""" # Get all the qualified sources and targets for a particular clause qualified_nodes_source = numba.typed.List.empty_list(node_type) qualified_nodes_target = numba.typed.List.empty_list(node_type) @@ -1413,8 +2288,9 @@ def get_qualified_components_edge_clause(interpretations_edge, candidates_source return qualified_nodes_source, qualified_nodes_target -@numba.njit(cache=False) +@numba.njit(cache=True) def get_qualified_components_edge_comparison_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph): + """NOTE: DEPRECATED""" # Get all the qualified sources and targets for a particular clause qualified_nodes_source = numba.typed.List.empty_list(node_type) qualified_nodes_target = numba.typed.List.empty_list(node_type) @@ -1431,8 +2307,9 @@ def get_qualified_components_edge_comparison_clause(interpretations_edge, candid return qualified_nodes_source, qualified_nodes_target, qualified_edges_numbers -@numba.njit(cache=False) +@numba.njit(cache=True) def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, qualified_nodes_2): + """NOTE: DEPRECATED""" result = False final_qualified_nodes_1 = numba.typed.List.empty_list(node_type) final_qualified_nodes_2 = numba.typed.List.empty_list(node_type) @@ -1463,8 +2340,9 @@ def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, return result, final_qualified_nodes_1, final_qualified_nodes_2 -@numba.njit(cache=False) +@numba.njit(cache=True) def 
compare_numbers_edge_predicate(numbers_1, numbers_2, op, qualified_nodes_1a, qualified_nodes_1b, qualified_nodes_2a, qualified_nodes_2b): + """NOTE: DEPRECATED""" result = False final_qualified_nodes_1a = numba.typed.List.empty_list(node_type) final_qualified_nodes_1b = numba.typed.List.empty_list(node_type) @@ -1499,7 +2377,7 @@ def compare_numbers_edge_predicate(numbers_1, numbers_2, op, qualified_nodes_1a, return result, final_qualified_nodes_1a, final_qualified_nodes_1b, final_qualified_nodes_2a, final_qualified_nodes_2b -@numba.njit(cache=False) +@numba.njit(cache=True) def _satisfies_threshold(num_neigh, num_qualified_component, threshold): # Checks if qualified neighbors satisfy threshold. This is for one clause if threshold[1][0]=='number': @@ -1531,8 +2409,8 @@ def _satisfies_threshold(num_neigh, num_qualified_component, threshold): return result -@numba.njit(cache=False) -def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): +@numba.njit(cache=True) +def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False # This is to prevent a key error in case the label is a specific label try: @@ -1543,6 +2421,10 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat # Add label to world if it is not there if l not in world.world: world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd prev_bnd = world.world[l].copy() @@ 
-1575,7 +2457,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1==l: + if p1 == l: + if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) @@ -1586,7 +2474,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2==l: + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) @@ -1620,8 +2514,8 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat return (False, 0) -@numba.njit(cache=False) -def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): +@numba.njit(cache=True) +def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, 
static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False # This is to prevent a key error in case the label is a specific label try: @@ -1632,6 +2526,10 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat # Add label to world if it is not there if l not in world.world: world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd prev_bnd = world.world[l].copy() @@ -1664,7 +2562,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1==l: + if p1 == l: + if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) @@ -1675,7 +2579,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2==l: + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) if atom_trace: _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), 
numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) @@ -1686,7 +2596,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper))) - + # Gather convergence data change = 0 if updated: @@ -1702,29 +2612,29 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat change = max(change, max_delta) else: change = 1 + ip_update_cnt - + return (updated, change) except: return (False, 0) -@numba.njit(cache=False) +@numba.njit(cache=True) def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): rule_trace.append((qn, qe, prev_bnd.copy(), name)) - -@numba.njit(cache=False) + +@numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (label, interval) in nas: - result = result and is_satisfied_node(interpretations, comp, (label, interval)) + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) return result -@numba.njit(cache=False) +@numba.njit(cache=True) def is_satisfied_node(interpretations, comp, na): result = False - if (not (na[0] is None or na[1] is None)): + if not (na[0] is None or na[1] is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] @@ -1736,7 +2646,7 @@ def is_satisfied_node(interpretations, comp, na): return result -@numba.njit(cache=False) +@numba.njit(cache=True) def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 @@ -1763,18 +2673,18 @@ def is_satisfied_node_comparison(interpretations, comp, na): return result, number -@numba.njit(cache=False) +@numba.njit(cache=True) def are_satisfied_edge(interpretations, comp, nas): result = 
True - for (label, interval) in nas: - result = result and is_satisfied_edge(interpretations, comp, (label, interval)) + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) return result -@numba.njit(cache=False) +@numba.njit(cache=True) def is_satisfied_edge(interpretations, comp, na): result = False - if (not (na[0] is None or na[1] is None)): + if not (na[0] is None or na[1] is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] @@ -1786,7 +2696,7 @@ def is_satisfied_edge(interpretations, comp, na): return result -@numba.njit(cache=False) +@numba.njit(cache=True) def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 @@ -1813,7 +2723,7 @@ def is_satisfied_edge_comparison(interpretations, comp, na): return result, number -@numba.njit(cache=False) +@numba.njit(cache=True) def annotate(annotation_functions, rule, annotations, weights): func_name = rule.get_annotation_function() if func_name == '': @@ -1826,7 +2736,7 @@ def annotate(annotation_functions, rule, annotations, weights): return annotation -@numba.njit(cache=False) +@numba.njit(cache=True) def check_consistent_node(interpretations, comp, na): world = interpretations[comp] if na[0] in world.world: @@ -1839,7 +2749,7 @@ def check_consistent_node(interpretations, comp, na): return True -@numba.njit(cache=False) +@numba.njit(cache=True) def check_consistent_edge(interpretations, comp, na): world = interpretations[comp] if na[0] in world.world: @@ -1852,20 +2762,26 @@ def check_consistent_edge(interpretations, comp, na): return True -@numba.njit(cache=False) -def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes): +@numba.njit(cache=True) +def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, 
rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): world = interpretations[comp] if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}') # Resolve inconsistency and set static world.world[na[0]].set_lower_upper(0, 1) world.world[na[0]].set_static(True) for p1, p2 in ipl: if p1==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}') world.world[p2].set_lower_upper(0, 1) world.world[p2].set_static(True) if store_interpretation_changes: @@ -1873,28 +2789,34 @@ def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, at if p2==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], 'Inconsistency') +
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}') world.world[p1].set_lower_upper(0, 1) world.world[p1].set_static(True) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) - # Add inconsistent predicates to a list + # Add inconsistent predicates to a list -@numba.njit(cache=False) -def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes): +@numba.njit(cache=True) +def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): w = interpretations[comp] if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}') # Resolve inconsistency and set static w.world[na[0]].set_lower_upper(0, 1) w.world[na[0]].set_static(True) for p1, p2 in ipl: if p1==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, 
numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}') w.world[p2].set_lower_upper(0, 1) w.world[p2].set_static(True) if store_interpretation_changes: @@ -1902,14 +2824,14 @@ def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, at if p2==na[0]: if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], 'Inconsistency') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}') w.world[p1].set_lower_upper(0, 1) w.world[p1].set_static(True) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) -@numba.njit(cache=False) +@numba.njit(cache=True) def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): nodes.append(node) neighbors[node] = numba.typed.List.empty_list(node_type) @@ -1917,8 +2839,8 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): interpretations_node[node] = world.World(numba.typed.List.empty_list(label.label_type)) -@numba.njit(cache=False) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge): +@numba.njit(cache=True) +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map): # If not a node, add to list of nodes 
and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1938,6 +2860,10 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int reverse_neighbors[target].append(source) if l.value!='': interpretations_edge[edge] = world.World(numba.typed.List([l])) + if l in predicate_map: + predicate_map[l].append(edge) + else: + predicate_map[l] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: @@ -1948,33 +2874,39 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int return edge, new_edge -@numba.njit(cache=False) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge): +@numba.njit(cache=True) +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes -@numba.njit(cache=False) -def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge): +@numba.njit(cache=True) +def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map): source, target = edge edges.remove(edge) del interpretations_edge[edge] + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) -@numba.njit(cache=False) 
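The `predicate_map` threaded through `_add_edge`, `_add_edges`, and `_delete_edge` above is an index from each predicate label to the components that currently carry it, so grounding can look components up by predicate instead of scanning the whole graph. A plain-Python sketch of that bookkeeping, using ordinary `dict`/`list` stand-ins for the `numba.typed` containers (the helper names here are illustrative, not the library's API):

```python
# Sketch of the predicate -> components index maintained by the diff's
# _add_edge / _delete_edge / _delete_node helpers, with plain containers
# in place of numba.typed.Dict / numba.typed.List.

def index_component(predicate_map, predicate, component):
    # Mirrors _add_edge: append under an existing label or create a new entry.
    if predicate in predicate_map:
        predicate_map[predicate].append(component)
    else:
        predicate_map[predicate] = [component]

def deindex_component(predicate_map, component):
    # Mirrors _delete_edge / _delete_node: drop the component from every label.
    for predicate in predicate_map:
        if component in predicate_map[predicate]:
            predicate_map[predicate].remove(component)

predicate_map = {}
index_component(predicate_map, 'isConnectedTo', ('A', 'B'))
index_component(predicate_map, 'isConnectedTo', ('B', 'C'))
deindex_component(predicate_map, ('A', 'B'))
# predicate_map is now {'isConnectedTo': [('B', 'C')]}
```

The real code pays for deletions with a linear scan over all labels; that matches the diff, which removes a component from every `predicate_map` entry that mentions it.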
-def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): +@numba.njit(cache=True) +def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map): nodes.remove(node) del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): @@ -1985,7 +2917,7 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node reverse_neighbors[n].remove(node) -@numba.njit(cache=False) +@numba.njit(cache=True) def float_to_str(value): number = int(value) decimal = int(value % 1 * 1000) @@ -1993,7 +2925,7 @@ def float_to_str(value): return float_str -@numba.njit(cache=False) +@numba.njit(cache=True) def str_to_float(value): decimal_pos = value.find('.') if decimal_pos != -1: @@ -2006,7 +2938,7 @@ def str_to_float(value): return value -@numba.njit(cache=False) +@numba.njit(cache=True) def str_to_int(value): if value[0] == '-': negative = True diff --git a/pyreason/scripts/interval/interval.py b/pyreason/scripts/interval/interval.py index a6ab76e..a6b7fb6 100755 --- a/pyreason/scripts/interval/interval.py +++ b/pyreason/scripts/interval/interval.py @@ -2,6 +2,7 @@ from numba import njit import numpy as np + class Interval(structref.StructRefProxy): def __new__(cls, l, u, s=False): return structref.StructRefProxy.__new__(cls, l, u, s, l, u) diff --git a/pyreason/scripts/numba_wrapper/numba_types/rule_type.py b/pyreason/scripts/numba_wrapper/numba_types/rule_type.py index 970d710..766114d 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/rule_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/rule_type.py @@ -32,8 +32,8 @@ def typeof_rule(val, c): # Construct object from Numba functions (Doesn't work. 
We don't need this currently) @type_callable(Rule) def type_rule(context): - def typer(rule_name, type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule): - if isinstance(rule_name, types.UnicodeType) and isinstance(type, types.UnicodeType) and isinstance(target, label.LabelType) and isinstance(delta, types.Integer) and isinstance(clauses, (types.NoneType, types.ListType)) and isinstance(bnd, interval.IntervalType) and isinstance(thresholds, types.ListType) and isinstance(ann_fn, types.UnicodeType) and isinstance(weights, types.Array) and isinstance(edges, types.Tuple) and isinstance(static, types.Boolean) and isinstance(immediate_rule, types.Boolean): + def typer(rule_name, type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule): + if isinstance(rule_name, types.UnicodeType) and isinstance(type, types.UnicodeType) and isinstance(target, label.LabelType) and isinstance(head_variables, types.ListType) and isinstance(delta, types.Integer) and isinstance(clauses, (types.NoneType, types.ListType)) and isinstance(bnd, interval.IntervalType) and isinstance(thresholds, types.ListType) and isinstance(ann_fn, types.UnicodeType) and isinstance(weights, types.Array) and isinstance(edges, types.Tuple) and isinstance(static, types.Boolean) and isinstance(immediate_rule, types.Boolean): return rule_type return typer @@ -46,6 +46,7 @@ def __init__(self, dmm, fe_type): ('rule_name', types.string), ('type', types.string), ('target', label.label_type), + ('head_variables', types.ListType(types.string)), ('delta', types.uint16), ('clauses', types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string)))), ('bnd', interval.interval_type), @@ -63,6 +64,7 @@ def __init__(self, dmm, fe_type): make_attribute_wrapper(RuleType, 'rule_name', 'rule_name') make_attribute_wrapper(RuleType, 'type', 'type') make_attribute_wrapper(RuleType, 
'target', 'target') +make_attribute_wrapper(RuleType, 'head_variables', 'head_variables') make_attribute_wrapper(RuleType, 'delta', 'delta') make_attribute_wrapper(RuleType, 'clauses', 'clauses') make_attribute_wrapper(RuleType, 'bnd', 'bnd') @@ -75,16 +77,18 @@ def __init__(self, dmm, fe_type): # Implement constructor -@lower_builtin(Rule, types.string, types.string, label.label_type, types.uint16, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), interval.interval_type, types.ListType(types.ListType(types.Tuple((types.string, types.string, types.float64)))), types.string, types.float64[::1], types.Tuple((types.string, types.string, label.label_type)), types.boolean, types.boolean) +@lower_builtin(Rule, types.string, types.string, label.label_type, types.ListType(types.string), types.uint16, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), interval.interval_type, types.ListType(types.ListType(types.Tuple((types.string, types.string, types.float64)))), types.string, types.float64[::1], types.Tuple((types.string, types.string, label.label_type)), types.boolean, types.boolean) def impl_rule(context, builder, sig, args): typ = sig.return_type - rule_name, type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule = args + rule_name, type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule = args + context.nrt.incref(builder, types.ListType(types.string), head_variables) context.nrt.incref(builder, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), clauses) context.nrt.incref(builder, types.ListType(types.Tuple((types.string, types.UniTuple(types.string, 2), types.float64))), thresholds) rule = cgutils.create_struct_proxy(typ)(context, builder) 
rule.rule_name = rule_name rule.type = type rule.target = target + rule.head_variables = head_variables rule.delta = delta rule.clauses = clauses rule.bnd = bnd @@ -119,6 +123,13 @@ def getter(rule): return getter +@overload_method(RuleType, "get_head_variables") +def get_head_variables(rule): + def getter(rule): + return rule.head_variables + return getter + + @overload_method(RuleType, "get_delta") def get_delta(rule): def getter(rule): @@ -133,6 +144,13 @@ def getter(rule): return getter +@overload_method(RuleType, "set_clauses") +def set_clauses(rule): + def setter(rule, clauses): + rule.clauses = clauses + return setter + + @overload_method(RuleType, "get_bnd") def get_bnd(rule): def impl(rule): @@ -188,6 +206,7 @@ def unbox_rule(typ, obj, c): name_obj = c.pyapi.object_getattr_string(obj, "_rule_name") type_obj = c.pyapi.object_getattr_string(obj, "_type") target_obj = c.pyapi.object_getattr_string(obj, "_target") + head_variables_obj = c.pyapi.object_getattr_string(obj, "_head_variables") delta_obj = c.pyapi.object_getattr_string(obj, "_delta") clauses_obj = c.pyapi.object_getattr_string(obj, "_clauses") bnd_obj = c.pyapi.object_getattr_string(obj, "_bnd") @@ -201,6 +220,7 @@ def unbox_rule(typ, obj, c): rule.rule_name = c.unbox(types.string, name_obj).value rule.type = c.unbox(types.string, type_obj).value rule.target = c.unbox(label.label_type, target_obj).value + rule.head_variables = c.unbox(types.ListType(types.string), head_variables_obj).value rule.delta = c.unbox(types.uint16, delta_obj).value rule.clauses = c.unbox(types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), clauses_obj).value rule.bnd = c.unbox(interval.interval_type, bnd_obj).value @@ -213,6 +233,7 @@ def unbox_rule(typ, obj, c): c.pyapi.decref(name_obj) c.pyapi.decref(type_obj) c.pyapi.decref(target_obj) + c.pyapi.decref(head_variables_obj) c.pyapi.decref(delta_obj) c.pyapi.decref(clauses_obj) 
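The `set_clauses` overload added here (together with `set_thresholds` on the Python-side `Rule`) exists so a rule's clauses can be rewritten after construction; the new `reorder_clauses` utility in this diff uses that pair to move node clauses ahead of edge clauses for faster grounding, recording a map back to the original positions so traces can be un-shuffled later. A plain-Python sketch of that reordering, with simplified `(kind, predicate)` clause tuples in place of the real numba tuples:

```python
# Sketch of reorder_clauses: node clauses first, edge clauses after,
# thresholds kept in lockstep, and clause_map recording new -> original index.

def reorder_clauses_sketch(clauses, thresholds):
    node_clauses = [(i, c) for i, c in enumerate(clauses) if c[0] == 'node']
    edge_clauses = [(i, c) for i, c in enumerate(clauses) if c[0] != 'node']
    new_clauses, new_thresholds, clause_map = [], [], {}
    for new_i, (orig_i, clause) in enumerate(node_clauses + edge_clauses):
        new_clauses.append(clause)
        new_thresholds.append(thresholds[orig_i])
        clause_map[new_i] = orig_i
    return new_clauses, new_thresholds, clause_map

clauses = [('edge', 'HaveAccess'), ('node', 'Viewed')]
thresholds = ['t_edge', 't_node']
c, t, m = reorder_clauses_sketch(clauses, thresholds)
# c == [('node', 'Viewed'), ('edge', 'HaveAccess')], m == {0: 1, 1: 0}
```

The `clause_map` returned here is what `Output` later consumes (as `clause_map` in `output.py`) to restore the user's original clause order in the saved rule trace.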
c.pyapi.decref(bnd_obj) @@ -233,6 +254,7 @@ def box_rule(typ, val, c): name_obj = c.box(types.string, rule.rule_name) type_obj = c.box(types.string, rule.type) target_obj = c.box(label.label_type, rule.target) + head_variables_obj = c.box(types.ListType(types.string), rule.head_variables) delta_obj = c.box(types.uint16, rule.delta) clauses_obj = c.box(types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), rule.clauses) bnd_obj = c.box(interval.interval_type, rule.bnd) @@ -242,10 +264,11 @@ def box_rule(typ, val, c): edges_obj = c.box(types.Tuple((types.string, types.string, label.label_type)), rule.edges) static_obj = c.box(types.boolean, rule.static) immediate_rule_obj = c.box(types.boolean, rule.immediate_rule) - res = c.pyapi.call_function_objargs(class_obj, (name_obj, type_obj, target_obj, delta_obj, clauses_obj, bnd_obj, thresholds_obj, ann_fn_obj, weights_obj, edges_obj, static_obj, immediate_rule_obj)) + res = c.pyapi.call_function_objargs(class_obj, (name_obj, type_obj, target_obj, head_variables_obj, delta_obj, clauses_obj, bnd_obj, thresholds_obj, ann_fn_obj, weights_obj, edges_obj, static_obj, immediate_rule_obj)) c.pyapi.decref(name_obj) c.pyapi.decref(type_obj) c.pyapi.decref(target_obj) + c.pyapi.decref(head_variables_obj) c.pyapi.decref(delta_obj) c.pyapi.decref(clauses_obj) c.pyapi.decref(ann_fn_obj) diff --git a/pyreason/scripts/program/program.py b/pyreason/scripts/program/program.py index 4c992ae..8adc5e8 100755 --- a/pyreason/scripts/program/program.py +++ b/pyreason/scripts/program/program.py @@ -8,7 +8,7 @@ class Program: specific_node_labels = [] specific_edge_labels = [] - def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode): + def __init__(self, graph, facts_node, facts_edge, 
rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode, allow_ground_rules): self._graph = graph self._facts_node = facts_node self._facts_edge = facts_edge @@ -23,6 +23,7 @@ def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functio self._store_interpretation_changes = store_interpretation_changes self._parallel_computing = parallel_computing self._update_mode = update_mode + self._allow_ground_rules = allow_ground_rules self.interp = None def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbose=True): @@ -35,9 +36,9 @@ def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbo # Instantiate correct interpretation class based on whether we parallelize the code or not. (We cannot parallelize with cache on) if self._parallel_computing: - self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode) + self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules) else: - self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode) + self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, 
self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules) self.interp.start_fp(self._tmax, self._facts_node, self._facts_edge, self._rules, verbose, convergence_threshold, convergence_bound_threshold) return self.interp diff --git a/pyreason/scripts/rules/rule.py b/pyreason/scripts/rules/rule.py index 73824c4..0ac1749 100755 --- a/pyreason/scripts/rules/rule.py +++ b/pyreason/scripts/rules/rule.py @@ -9,14 +9,14 @@ class Rule: 1. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0 TODO: Add weights as a parameter """ - def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None): + def __init__(self, rule_text: str, name: str = None, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None): """ :param rule_text: The rule in text format :param name: The name of the rule. This will appear in the rule trace :param infer_edges: Whether to infer new edges after edge rule fires :param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change :param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied + :param custom_thresholds: A list of custom thresholds for the rule. If not specified, the default thresholds for ANY are used. 
It can be a list with + one threshold per clause, or a map from clause index to threshold """ - if custom_thresholds is None: - custom_thresholds = [] self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, immediate_rule) diff --git a/pyreason/scripts/rules/rule_internal.py index 69de3ca..c5a6054 100755 --- a/pyreason/scripts/rules/rule_internal.py +++ b/pyreason/scripts/rules/rule_internal.py @@ -1,9 +1,10 @@ class Rule: - def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule): + def __init__(self, rule_name, rule_type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule): self._rule_name = rule_name self._type = rule_type self._target = target + self._head_variables = head_variables self._delta = delta self._clauses = clauses self._bnd = bnd @@ -17,23 +18,35 @@ def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds def get_rule_name(self): return self._rule_name + def set_rule_name(self, rule_name): + self._rule_name = rule_name + def get_rule_type(self): return self._type def get_target(self): return self._target + def get_head_variables(self): + return self._head_variables + def get_delta(self): return self._delta - def get_neigh_criteria(self): + def get_clauses(self): return self._clauses + + def set_clauses(self, clauses): + self._clauses = clauses def get_bnd(self): return self._bnd def get_thresholds(self): - return self._thresholds + return self._thresholds + + def set_thresholds(self, thresholds): + self._thresholds = thresholds def get_annotation_function(self): return self._ann_fn diff --git a/pyreason/scripts/threshold/threshold.py index 3972263..1a4ee64 100644 --- a/pyreason/scripts/threshold/threshold.py +++ b/pyreason/scripts/threshold/threshold.py @@ -38,4 +38,4 @@ def to_tuple(self):
Returns: tuple: A tuple representation of the Threshold instance. """ - return (self.quantifier, self.quantifier_type, self.thresh) \ No newline at end of file + return self.quantifier, self.quantifier_type, self.thresh diff --git a/pyreason/scripts/utils/fact_parser.py b/pyreason/scripts/utils/fact_parser.py new file mode 100644 index 0000000..6b3c922 --- /dev/null +++ b/pyreason/scripts/utils/fact_parser.py @@ -0,0 +1,40 @@ +import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval + + +def parse_fact(fact_text): + f = fact_text.replace(' ', '') + + # Separate into predicate-component and bound. If there is no bound it means it's true + if ':' in f: + pred_comp, bound = f.split(':') + else: + pred_comp = f + if pred_comp[0] == '~': + bound = 'False' + pred_comp = pred_comp[1:] + else: + bound = 'True' + + # Check if bound is a boolean or a list of floats + bound = bound.lower() + if bound == 'true': + bound = interval.closed(1, 1) + elif bound == 'false': + bound = interval.closed(0, 0) + else: + bound = [float(b) for b in bound[1:-1].split(',')] + bound = interval.closed(*bound) + + # Split the predicate and component + idx = pred_comp.find('(') + pred = pred_comp[:idx] + component = pred_comp[idx + 1:-1] + + # Check if it is a node or edge fact + if ',' in component: + fact_type = 'edge' + component = tuple(component.split(',')) + else: + fact_type = 'node' + + return pred, component, bound, fact_type diff --git a/pyreason/scripts/utils/output.py b/pyreason/scripts/utils/output.py index e0d12f9..e680083 100755 --- a/pyreason/scripts/utils/output.py +++ b/pyreason/scripts/utils/output.py @@ -4,8 +4,9 @@ class Output: - def __init__(self, timestamp): + def __init__(self, timestamp, clause_map=None): self.timestamp = timestamp + self.clause_map = clause_map self.rule_trace_node = None self.rule_trace_edge = None @@ -80,6 +81,14 @@ def _parse_internal_rule_trace(self, interpretation): # Store the trace in a DataFrame self.rule_trace_edge = 
pd.DataFrame(data, columns=header_edge) + # Now do the reordering + if self.clause_map is not None: + offset = 7 + columns_to_reorder_node = header_node[offset:] + columns_to_reorder_edge = header_edge[offset:] + self.rule_trace_node = self.rule_trace_node.apply(self._reorder_row, axis=1, map_dict=self.clause_map, columns_to_reorder=columns_to_reorder_node) + self.rule_trace_edge = self.rule_trace_edge.apply(self._reorder_row, axis=1, map_dict=self.clause_map, columns_to_reorder=columns_to_reorder_edge) + def save_rule_trace(self, interpretation, folder='./'): if self.rule_trace_node is None and self.rule_trace_edge is None: self._parse_internal_rule_trace(interpretation) @@ -94,3 +103,14 @@ def get_rule_trace(self, interpretation): self._parse_internal_rule_trace(interpretation) return self.rule_trace_node, self.rule_trace_edge + + @staticmethod + def _reorder_row(row, map_dict, columns_to_reorder): + if row['Occurred Due To'] in map_dict: + original_values = row[columns_to_reorder].values + new_values = [None] * len(columns_to_reorder) + for orig_pos, target_pos in map_dict[row['Occurred Due To']].items(): + new_values[target_pos] = original_values[orig_pos] + for i, col in enumerate(columns_to_reorder): + row[col] = new_values[i] + return row diff --git a/pyreason/scripts/utils/reorder_clauses.py b/pyreason/scripts/utils/reorder_clauses.py new file mode 100644 index 0000000..11408ff --- /dev/null +++ b/pyreason/scripts/utils/reorder_clauses.py @@ -0,0 +1,30 @@ +import numba +import pyreason.scripts.numba_wrapper.numba_types.label_type as label +import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval + + +def reorder_clauses(rule): + # Go through all clauses in the rule and re-order them if necessary + # It is faster for grounding to have node clauses first and then edge clauses + # Move all the node clauses to the front of the list + reordered_clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, 
numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string))) + reordered_thresholds = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.UniTuple(numba.types.string, 2), numba.types.float64))) + node_clauses = [] + edge_clauses = [] + reordered_clauses_map = {} + + for index, clause in enumerate(rule.get_clauses()): + if clause[0] == 'node': + node_clauses.append((index, clause)) + else: + edge_clauses.append((index, clause)) + + thresholds = rule.get_thresholds() + for new_index, (original_index, clause) in enumerate(node_clauses + edge_clauses): + reordered_clauses.append(clause) + reordered_thresholds.append(thresholds[original_index]) + reordered_clauses_map[new_index] = original_index + + rule.set_clauses(reordered_clauses) + rule.set_thresholds(reordered_thresholds) + return rule, reordered_clauses_map diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py index 36bdede..1741911 100644 --- a/pyreason/scripts/utils/rule_parser.py +++ b/pyreason/scripts/utils/rule_parser.py @@ -1,12 +1,15 @@ import numba import numpy as np +from typing import Union import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule +# import pyreason.scripts.rules.rule_internal as rule import pyreason.scripts.numba_wrapper.numba_types.label_type as label import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval +from pyreason.scripts.threshold.threshold import Threshold -def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule: +def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, dict], infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule: # First remove all spaces from line r = rule_text.replace(' ', '') @@ -33,7 +36,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, 
infer_edges: # 2. replace ) by )) and ] by ]] so that we can split without damaging the string # 3. Split with ), and then for each element of list, split with ], and add to new list # 4. Then replace ]] with ] and )) with ) in for loop - # 5. Add :[1,1] to the end of each element if a bound is not specified + # 5. Add :[1,1] or :[0,0] to the end of each element if a bound is not specified # 6. Then split each element with : # 7. Transform bound strings into pr.intervals @@ -54,7 +57,9 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: # 5 for i in range(len(split_body)): - if split_body[i][-1] != ']': + if split_body[i][0] == '~': + split_body[i] = split_body[i][1:] + ':[0,0]' + elif split_body[i][-1] != ']': split_body[i] += ':[1,1]' # 6 @@ -65,6 +70,14 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: body_clauses.append(clause) body_bounds.append(bound) + # Check if there are custom thresholds for the rule such as forall in string form + for i, b in enumerate(body_clauses.copy()): + if 'forall(' in b: + if not custom_thresholds: + custom_thresholds = {} + custom_thresholds[i] = Threshold("greater_equal", ("percent", "total"), 100) + body_clauses[i] = b[:-1].replace('forall(', '') + # 7 for i in range(len(body_bounds)): bound = body_bounds[i] @@ -79,7 +92,10 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: # This means there is no bound or annotation function specified if head[-1] == ')': - head += ':[1,1]' + if head[0] == '~': + head = head[1:] + ':[0,0]' + else: + head += ':[1,1]' head, head_bound = head.split(':') # Check if we have a bound or annotation function @@ -123,25 +139,6 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: if rule_type == 'node': infer_edges = False - # Replace the variables in the body with source/target if they match the variables in the head - # If infer_edges is true, then we consider all rules to be 
node rules, we infer the 2nd variable of the target predicate from the rule body
-    # Else we consider the rule to be an edge rule and replace variables with source/target
-    # Node rules with possibility of adding edges
-    if infer_edges or len(head_variables) == 1:
-        head_source_variable = head_variables[0]
-        for i in range(len(body_variables)):
-            for j in range(len(body_variables[i])):
-                if body_variables[i][j] == head_source_variable:
-                    body_variables[i][j] = '__target'
-    # Edge rule, no edges to be added
-    elif len(head_variables) == 2:
-        for i in range(len(body_variables)):
-            for j in range(len(body_variables[i])):
-                if body_variables[i][j] == head_variables[0]:
-                    body_variables[i][j] = '__source'
-                elif body_variables[i][j] == head_variables[1]:
-                    body_variables[i][j] = '__target'
-
     # Start setting up clauses
     # clauses = [c1, c2, c3, c4]
     # thresholds = [t1, t2, t3, t4]
@@ -155,18 +152,25 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
     # gather count of clauses for threshold validation
     num_clauses = len(body_clauses)
-    if custom_thresholds and (len(custom_thresholds) != num_clauses):
-        raise Exception('The length of custom thresholds {} is not equal to number of clauses {}'
-                        .format(len(custom_thresholds), num_clauses))
-
+    if isinstance(custom_thresholds, list):
+        if len(custom_thresholds) != num_clauses:
+            raise Exception(f'The length of custom thresholds {len(custom_thresholds)} is not equal to number of clauses {num_clauses}')
+        for threshold in custom_thresholds:
+            thresholds.append(threshold.to_tuple())
+    elif isinstance(custom_thresholds, dict):
+        if max(custom_thresholds.keys()) >= num_clauses:
+            raise Exception(f'The max clause index in the custom thresholds map {max(custom_thresholds.keys())} is greater than number of clauses {num_clauses}')
+        for i in range(num_clauses):
+            if i in custom_thresholds:
+                thresholds.append(custom_thresholds[i].to_tuple())
+            else:
+                thresholds.append(('greater_equal', ('number', 'total'), 1.0))
+
     # If no custom thresholds provided, use defaults
     # otherwise loop through user-defined thresholds and convert to numba compatible format
-    if not custom_thresholds:
+    elif not custom_thresholds:
         for _ in range(num_clauses):
             thresholds.append(('greater_equal', ('number', 'total'), 1.0))
-    else:
-        for threshold in custom_thresholds:
-            thresholds.append(threshold.to_tuple())
 
     # # Loop though clauses
     for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
@@ -184,15 +188,18 @@
     # Assert that there are two variables in the head of the rule if we infer edges
     # Add edges between head variables if necessary
     if infer_edges:
-        var = '__target' if head_variables[0] == head_variables[1] else head_variables[1]
-        edges = ('__target', var, target)
+        # var = '__target' if head_variables[0] == head_variables[1] else head_variables[1]
+        # edges = ('__target', var, target)
+        edges = (head_variables[0], head_variables[1], target)
     else:
         edges = ('', '', label.Label(''))
 
     weights = np.ones(len(body_predicates), dtype=np.float64)
     weights = np.append(weights, 0)
 
-    r = rule.Rule(name, rule_type, target, numba.types.uint16(t), clauses, target_bound, thresholds, ann_fn, weights, edges, set_static, immediate_rule)
+    head_variables = numba.typed.List(head_variables)
+
+    r = rule.Rule(name, rule_type, target, head_variables, numba.types.uint16(t), clauses, target_bound, thresholds, ann_fn, weights, edges, set_static, immediate_rule)
 
     return r
diff --git a/requirements.txt b/requirements.txt
index 655c2c3..25f9bc1 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,11 +2,12 @@ networkx
 pyyaml
 pandas
 numba==0.59.1
-numpy
+numpy==1.26.4
 memory_profiler
 pytest
+setuptools_scm
 sphinx_rtd_theme
 sphinx
 sphinx-autopackagesummary
-sphinx-autoapi
\ No newline at end of file
+sphinx-autoapi
diff --git a/setup.py b/setup.py
index 5cc9b95..8705509 100644
--- a/setup.py
+++ b/setup.py
@@ -4,11 +4,11 @@ from pathlib import Path
 
 this_directory = Path(__file__).parent
-long_description = (this_directory / "README.md").read_text()
+long_description = (this_directory / "README.md").read_text(encoding='UTF-8')
 
 setup(
     name='pyreason',
-    version='2.3.0',
+    version='3.0.0',
     author='Dyuman Aditya',
     author_email='dyuman.aditya@gmail.com',
     description='An explainable inference software supporting annotated, real valued, graph based and temporal logic',
@@ -35,6 +35,8 @@
         'memory_profiler',
         'pytest'
     ],
+    use_scm_version=True,
+    setup_requires=['setuptools_scm'],
     packages=find_packages(),
     include_package_data=True
 )
diff --git a/tests/group_chat_graph.graphml b/tests/group_chat_graph.graphml
index 7c76e29..852d05c 100644
--- a/tests/group_chat_graph.graphml
+++ b/tests/group_chat_graph.graphml
@@ -1,26 +1,23 @@
[GraphML hunk not preserved: the XML markup was stripped during extraction; only stray "1" data values survive]
diff --git a/tests/knowledge_graph_test_subset.graphml b/tests/knowledge_graph_test_subset.graphml
new file mode 100644
index 0000000..72e5c23
--- /dev/null
+++ b/tests/knowledge_graph_test_subset.graphml
@@ -0,0 +1,71 @@
[GraphML content not preserved: the XML markup was stripped during extraction; only stray "1" data values survive]
diff --git a/tests/test_annotation_function.py b/tests/test_annotation_function.py
new file mode 100644
index 0000000..22deb10
--- /dev/null
+++ b/tests/test_annotation_function.py
@@ -0,0 +1,36 @@
+# Test if annotation functions work
+import pyreason as pr
+import numba
+import numpy as np
+
+
+@numba.njit
+def probability_func(annotations, weights):
+    prob_A = annotations[0][0].lower
+    prob_B = annotations[1][0].lower
+    union_prob = prob_A + prob_B
+    union_prob = np.round(union_prob, 3)
+    return union_prob, 1
+
+
+def test_annotation_function():
+    # Reset PyReason
+    pr.reset()
+    pr.reset_rules()
+
+    pr.settings.allow_ground_rules = True
+
+    pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
+    pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
+    pr.add_annotation_function(probability_func)
+    pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+    interpretation = pr.reason(timesteps=1)
+
+    dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+
+    assert interpretation.query('union_probability(A, B) : [0.21, 1]'), 'Union probability should be 0.21'
diff --git a/tests/test_anyBurl_infer_edges_rules.py b/tests/test_anyBurl_infer_edges_rules.py
new file mode 100644
index 0000000..0ae5df7
--- /dev/null
+++ b/tests/test_anyBurl_infer_edges_rules.py
@@ -0,0 +1,139 @@
+import pyreason as pr
+
+
+def test_anyBurl_rule_1():
+    graph_path = './tests/knowledge_graph_test_subset.graphml'
+    pr.reset()
+    pr.reset_rules()
+    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    pr.settings.verbose = True
+    pr.settings.atom_trace = True
+    pr.settings.memory_profile = False
+    pr.settings.canonical = True
+    pr.settings.inconsistency_check = False
+    pr.settings.static_graph_facts = False
+    pr.settings.output_to_file = False
+    pr.settings.store_interpretation_changes = True
+    pr.settings.save_graph_attributes_to_trace = True
+    # Load all the files into pyreason
+    pr.load_graphml(graph_path)
+    pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+    # Run the program for one timestep to see the diffusion take place
+    interpretation = pr.reason(timesteps=1)
+    # pr.save_rule_trace(interpretation)
+
+    # Display the changes in the interpretation for each timestep
+    dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+    assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+    assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+    assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_2():
+    graph_path = './tests/knowledge_graph_test_subset.graphml'
+    pr.reset()
+    pr.reset_rules()
+    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    pr.settings.verbose = True
+    pr.settings.atom_trace = True
+    pr.settings.memory_profile = False
+    pr.settings.canonical = True
+    pr.settings.inconsistency_check = False
+    pr.settings.static_graph_facts = False
+    pr.settings.output_to_file = False
+    pr.settings.store_interpretation_changes = True
+    pr.settings.save_graph_attributes_to_trace = True
+    pr.settings.parallel_computing = False
+    # Load all the files into pyreason
+    pr.load_graphml(graph_path)
+
+    pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True))
+
+    # Run the program for one timestep to see the diffusion take place
+    interpretation = pr.reason(timesteps=1)
+    # pr.save_rule_trace(interpretation)
+
+    # Display the changes in the interpretation for each timestep
+    dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+    assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+    assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+    assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_3():
+    graph_path = './tests/knowledge_graph_test_subset.graphml'
+    pr.reset()
+    pr.reset_rules()
+    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    pr.settings.verbose = True
+    pr.settings.atom_trace = True
+    pr.settings.memory_profile = False
+    pr.settings.canonical = True
+    pr.settings.inconsistency_check = False
+    pr.settings.static_graph_facts = False
+    pr.settings.output_to_file = False
+    pr.settings.store_interpretation_changes = True
+    pr.settings.save_graph_attributes_to_trace = True
+    pr.settings.parallel_computing = False
+    # Load all the files into pyreason
+    pr.load_graphml(graph_path)
+
+    pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True))
+
+    # Run the program for one timestep to see the diffusion take place
+    interpretation = pr.reason(timesteps=1)
+    # pr.save_rule_trace(interpretation)
+
+    # Display the changes in the interpretation for each timestep
+    dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+    assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+    assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+    assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_4():
+    graph_path = './tests/knowledge_graph_test_subset.graphml'
+    pr.reset()
+    pr.reset_rules()
+    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    pr.settings.verbose = True
+    pr.settings.atom_trace = True
+    pr.settings.memory_profile = False
+    pr.settings.canonical = True
+    pr.settings.inconsistency_check = False
+    pr.settings.static_graph_facts = False
+    pr.settings.output_to_file = False
+    pr.settings.store_interpretation_changes = True
+    pr.settings.save_graph_attributes_to_trace = True
+    pr.settings.parallel_computing = False
+    # Load all the files into pyreason
+    pr.load_graphml(graph_path)
+
+    pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True))
+
+    # Run the program for one timestep to see the diffusion take place
+    interpretation = pr.reason(timesteps=1)
+    # pr.save_rule_trace(interpretation)
+
+    # Display the changes in the interpretation for each timestep
+    dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+    assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+    assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+    assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
diff --git a/tests/test_custom_thresholds.py b/tests/test_custom_thresholds.py
index b982bf3..e1ae437 100644
--- a/tests/test_custom_thresholds.py
+++ b/tests/test_custom_thresholds.py
@@ -11,7 +11,8 @@ def test_custom_thresholds():
     # Modify the paths based on where you've stored the files we made above
     graph_path = "./tests/group_chat_graph.graphml"
 
-    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    # Modify pyreason settings to make verbose
+    pr.reset_settings()
     pr.settings.verbose = True  # Print info to screen
 
     # Load all the files into pyreason
@@ -25,16 +26,16 @@ def test_custom_thresholds():
 
     pr.add_rule(
         pr.Rule(
-            "ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
+            "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
             "viewed_by_all_rule",
             custom_thresholds=user_defined_thresholds,
         )
     )
 
-    pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
-    pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
-    pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
-    pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))
+    pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+    pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+    pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+    pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
 
     # Run the program for three timesteps to see the diffusion take place
     interpretation = pr.reason(timesteps=3)
diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py
index c932daf..84e84d2 100644
--- a/tests/test_hello_world.py
+++ b/tests/test_hello_world.py
@@ -1,24 +1,28 @@
 # Test if the simple hello world program works
 import pyreason as pr
+import faulthandler
 
 
 def test_hello_world():
     # Reset PyReason
     pr.reset()
     pr.reset_rules()
+    pr.reset_settings()
 
     # Modify the paths based on where you've stored the files we made above
     graph_path = './tests/friends_graph.graphml'
 
-    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    # Modify pyreason settings to make verbose
     pr.settings.verbose = True  # Print info to screen
+    # pr.settings.optimize_rules = False  # Disable rule optimization for debugging
 
     # Load all the files into pyreason
     pr.load_graphml(graph_path)
     pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
-    pr.add_fact(pr.Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
+    pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
 
     # Run the program for two timesteps to see the diffusion take place
+    faulthandler.enable()
     interpretation = pr.reason(timesteps=2)
 
     # Display the changes in the interpretation for each timestep
@@ -29,8 +33,8 @@ def test_hello_world():
     print()
 
     assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
-    assert len(dataframes[1]) == 2, 'At t=0 there should be two popular people'
-    assert len(dataframes[2]) == 3, 'At t=0 there should be three popular people'
+    assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+    assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
 
     # Mary should be popular in all three timesteps
     assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
diff --git a/tests/test_hello_world_parallel.py b/tests/test_hello_world_parallel.py
index 1b7ee03..fe47a33 100644
--- a/tests/test_hello_world_parallel.py
+++ b/tests/test_hello_world_parallel.py
@@ -10,14 +10,14 @@ def test_hello_world_parallel():
     # Modify the paths based on where you've stored the files we made above
     graph_path = './tests/friends_graph.graphml'
 
-    # Modify pyreason settings to make verbose and to save the rule trace to a file
+    # Modify pyreason settings to make verbose
+    pr.reset_settings()
     pr.settings.verbose = True  # Print info to screen
-    pr.settings.parallel_computing = True
 
     # Load all the files into pyreason
     pr.load_graphml(graph_path)
     pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
-    pr.add_fact(pr.Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
+    pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
 
     # Run the program for two timesteps to see the diffusion take place
     interpretation = pr.reason(timesteps=2)
@@ -30,8 +30,8 @@ def test_hello_world_parallel():
     print()
 
     assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
-    assert len(dataframes[1]) == 2, 'At t=0 there should be two popular people'
-    assert len(dataframes[2]) == 3, 'At t=0 there should be three popular people'
+    assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+    assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
 
     # Mary should be popular in all three timesteps
     assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
diff --git a/tests/test_reorder_clauses.py b/tests/test_reorder_clauses.py
new file mode 100644
index 0000000..6407f9b
--- /dev/null
+++ b/tests/test_reorder_clauses.py
@@ -0,0 +1,52 @@
+# Test clause reordering on the simple hello world program
+import pyreason as pr
+
+
+def test_reorder_clauses():
+    # Reset PyReason
+    pr.reset()
+    pr.reset_rules()
+    pr.reset_settings()
+
+    # Modify the paths based on where you've stored the files we made above
+    graph_path = './tests/friends_graph.graphml'
+
+    # Modify pyreason settings to make verbose
+    pr.settings.verbose = True  # Print info to screen
+    pr.settings.atom_trace = True  # Print atom trace
+
+    # Load all the files into pyreason
+    pr.load_graphml(graph_path)
+    pr.add_rule(pr.Rule('popular(x) <-1 Friends(x,y), popular(y), owns(y,z), owns(x,z)', 'popular_rule'))
+    pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
+
+    # Run the program for two timesteps to see the diffusion take place
+    interpretation = pr.reason(timesteps=2)
+
+    # Display the changes in the interpretation for each timestep
+    dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
+    for t, df in enumerate(dataframes):
+        print(f'TIMESTEP - {t}')
+        print(df)
+        print()
+
+    assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
+    assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+    assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
+
+    # Mary should be popular in all three timesteps
+    assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
+    assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps'
+    assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps'
+
+    # Justin should be popular in timesteps 1, 2
+    assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps'
+    assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
+
+    # John should be popular in timestep 2
+    assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
+
+    # Now look at the trace and make sure the order has gone back to the original rule
+    # The third row, clause 1 should be the edge grounding ('Justin', 'Mary')
+    rule_trace_node, _ = pr.get_rule_trace(interpretation)
+    assert rule_trace_node.iloc[2]['Clause-1'][0] == ('Justin', 'Mary')