Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(solver): support math opeator #88

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kag/solver/logic/core_modules/common/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self):
self.alias_name: Identifer = None
self.type_set: List[TypeInfo] = []
self.is_attribute = False
self.value_list = []
self.value_list = {}

def __repr__(self):
return f"{self.alias_name}:{self.get_entity_first_type_or_en()}"
Expand Down
8 changes: 7 additions & 1 deletion kag/solver/logic/core_modules/common/one_hop_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,12 @@ def __repr__(self):
from_entity_desc_str = "" if from_entity_desc is None else f"({from_entity_desc})"
to_entity_desc = self._get_entity_description(self.end_entity)
to_entity_desc_str = "" if to_entity_desc is None else f"({to_entity_desc})"
return f"({self.from_entity.name}{from_entity_desc_str} {self.type} {self.end_entity.name}{to_entity_desc_str})"
spo = f"({self.from_entity.name}{from_entity_desc_str} {self.type} {self.end_entity.name}{to_entity_desc_str})"
prop_map = self.prop.get_properties_map_list_value() if self.prop else {}
if prop_map:
prop_str = ",".join([f"{k}={';'.join(v)}" for k, v in prop_map.items()])
return f"{spo} with prop: {prop_str}"
return spo

@staticmethod
def from_dict(json_dict: dict, schema: SchemaUtils):
Expand Down Expand Up @@ -585,6 +590,7 @@ def __init__(self):
self.edge_alias = []
self.entity_map = {}
self.edge_map = {}
self.symb_values = {}

def merge_kg_graph(self, other, wo_intersect=True):
self.nodes_alias = list(set(self.nodes_alias + other.nodes_alias))
Expand Down
2 changes: 1 addition & 1 deletion kag/solver/logic/core_modules/lf_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def _execute_lf(self, sub_logic_nodes):
if isinstance(deduce_res, list):
kg_qa_result += deduce_res
elif self.math_executor.is_this_op(n):
self.math_executor.executor(n, self.req_id, self.params)
kg_qa_result = self.math_executor.executor(n, self.req_id, self.params)
elif self.sort_executor.is_this_op(n):
self.sort_executor.executor(n, self.req_id, self.params)
elif self.output_executor.is_this_op(n):
Expand Down
13 changes: 10 additions & 3 deletions kag/solver/logic/core_modules/op_executor/op_math/math_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
from kag.solver.logic.core_modules.parser.logic_node_parser import CountNode, SumNode
from kag.solver.logic.core_modules.op_executor.op_math.sympy_math.sympy_math_op import SymPyMathOp
from kag.solver.logic.core_modules.parser.logic_node_parser import CountNode, SumNode, MathNode


class MathExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)

self.op_mapping = {
'math': SymPyMathOp(self.nl_query, self.kg_graph, self.schema, self.debug_info, **kwargs)
}

def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, (CountNode, SumNode))
return isinstance(logic_node, (CountNode, SumNode, MathNode))

def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
pass
if isinstance(logic_node, MathNode):
return self.op_mapping['math'].executor(logic_node, req_id, param)
raise NotImplementedError(f"{logic_node}")
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import sympy as sp
from kag.solver.logic.core_modules.op_executor.op_math.sympy_math.unary.unary_set_op import CountSet, SumSet, \
AverageSet, MaxSet, MinSet, AbsSet

# Dictionary of custom functions for sympy
from sympy import Basic

custom_functions = {
'count': CountSet(),
'sum': SumSet(),
'average': AverageSet(),
'max': MaxSet(),
'min': MinSet(),
'abs': AbsSet()
}

custom_functions_call = {
'count': CountSet().process,
'sum': SumSet().process,
'average': AverageSet().process,
'max': MaxSet().process,
'min': MinSet().process,
'abs': AbsSet().process
}
def is_number(s):
"""检查字符串是否可以转换为数字。"""
try:
float(s)
return True
except ValueError:
return False
def evaluate_expression_eval(expression, data_dict):
"""
Evaluates a mathematical expression using SymPy and custom functions.

:param expression: A string representing the mathematical expression to be evaluated.
:param data_dict: A dictionary containing variable names and their corresponding values.
:return: The result of the evaluated expression.
"""
for key, value_list in data_dict.items():
# 检查列表是否不为空,并且所有元素都是数字
if value_list and all(is_number(value) for value in value_list):
# 将所有字符串转换为浮点数
data_dict[key] = [float(value) for value in value_list]
data_dict.update(custom_functions_call)
result = eval(expression, data_dict)

return result

def evaluate_expression_sympy(expression, data_dict):
"""
Evaluates a mathematical expression using SymPy and custom functions.

:param expression: A string representing the mathematical expression to be evaluated.
:param data_dict: A dictionary containing variable names and their corresponding values.
:return: The result of the evaluated expression.
"""
# Parse the expression using SymPy
expr = sp.sympify(expression)

# Extract variables from the data dictionary
variables = {sp.Symbol(key): value for key, value in data_dict.items()}

# Substitute the variables into the expression
result = expr.subs(variables)

# Process custom functions in the result
for func_name, func in custom_functions.items():
if func_name in str(result):
result = result.replace(sp.Function(func_name), lambda *args: func.process(*args))
if isinstance(result, Basic):
# Check if all functions in the result are implemented
for func in result.atoms(sp.Function):
if func.func.__name__ not in custom_functions:
raise NotImplementedError(f"Function '{func.func.__name__}' is not implemented.")

return result

def evaluate_expression(expression, data_dict):
return evaluate_expression_eval(expression, data_dict)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from sympy import FiniteSet

from kag.common.base.prompt_op import PromptOp
from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph, EntityData
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
from kag.solver.logic.core_modules.parser.logic_node_parser import MathNode

from kag.solver.logic.core_modules.op_executor.op_math.sympy_math.custom_function import evaluate_expression


class SymPyMathOp(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.expression_builder = PromptOp.load(self.biz_scene, "expression_builder")(
language=self.language, project_id=self.project_id
)

def _convert_kg_graph_2_variable_data_dict(self):
data_set = {}
for p, spo in self.kg_graph.query_graph.items():
def convert_finite_set(alias_var):
alias = str(alias_var)
if alias not in data_set.keys():
alias_set = self.kg_graph.get_entity_by_alias(alias_var)
alias_set_data = []
if alias_set:
for alias_data in alias_set:
if isinstance(alias_data, EntityData):
alias_set_data.append(alias_data.biz_id)
else:
alias_set_data.append(str(alias_data))

data_set[alias] = alias_set_data
else:
data_set[alias] = alias_set_data

convert_finite_set(p)
convert_finite_set(spo["s"])
convert_finite_set(spo["o"])

data_set.update(self.kg_graph.symb_values)
return data_set

def executor(self, logic_node: MathNode, req_id: str, param: dict) -> list:
data_set = self._convert_kg_graph_2_variable_data_dict()
result = evaluate_expression(logic_node.expr, data_set)
self.kg_graph.symb_values[logic_node.alias_name] = result
return [result]
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from sympy import Symbol


class UnarySetOp:
"""
Base class for unary set operations.
"""

def __init__(self):
"""
Initializes the base unary set operation.
"""
self.op_name = 'base'

def pre_check(self, A):
"""
Pre-checks the input to ensure it is valid.

:param A: The input set or symbol.
:raises ValueError: If the input is a symbolic variable.
"""
if isinstance(A, Symbol):
raise ValueError(f"Undefined variables: {A}")

def do_process(self, A):
"""
Abstract method to perform the specific operation on the input set.

:param A: The input set.
:raises NotImplementedError: If the operation is not implemented.
"""
raise NotImplementedError(f"{self.op_name} is not implemented")

def process(self, A):
"""
Processes the input set by performing the pre-check and then the specific operation.

:param A: The input set.
:return: The result of the operation.
"""
self.pre_check(A)
return self.do_process(A)


class CountSet(UnarySetOp):
"""
Class for counting the number of elements in a set.
"""

def __init__(self):
"""
Initializes the count set operation.
"""
super(CountSet, self).__init__()
self.op_name = "count"

def do_process(self, A):
"""
Counts the number of elements in the input set.

:param A: The input set.
:return: The number of elements in the set.
"""
return len(A)


class SumSet(UnarySetOp):
"""
Class for summing the elements in a set.
"""

def __init__(self):
"""
Initializes the sum set operation.
"""
super(SumSet, self).__init__()
self.op_name = "sum"

def do_process(self, A):
"""
Sums the elements in the input set.

:param A: The input set.
:return: The sum of the elements in the set.
"""
return sum(A)


class AverageSet(UnarySetOp):
"""
Class for computing the average of elements in a set.
"""

def __init__(self):
"""
Initializes the average set operation.
"""
super(AverageSet, self).__init__()
self.op_name = "average"

def do_process(self, A):
"""
Computes the average of the elements in the input set.

:param A: The input set.
:return: The average of the elements in the set.
:raises ValueError: If the input set is empty.
"""
if len(A) == 0:
raise ValueError("Cannot compute average of an empty set")
return sum(A) / len(A)


class MaxSet(UnarySetOp):
"""
Class for finding the maximum element in a set.
"""

def __init__(self):
"""
Initializes the max set operation.
"""
super(MaxSet, self).__init__()
self.op_name = "max"

def do_process(self, A):
"""
Finds the maximum element in the input set.

:param A: The input set.
:return: The maximum element in the set.
:raises ValueError: If the input set is empty.
"""
if len(A) == 0:
raise ValueError("Cannot compute max of an empty set")
return max(A)


class MinSet(UnarySetOp):
"""
Class for finding the minimum element in a set.
"""

def __init__(self):
"""
Initializes the min set operation.
"""
super(MinSet, self).__init__()
self.op_name = "min"

def do_process(self, A):
"""
Finds the minimum element in the input set.

:param A: The input set.
:return: The minimum element in the set.
:raises ValueError: If the input set is empty.
"""
if len(A) == 0:
raise ValueError("Cannot compute min of an empty set")
return min(A)


class AbsSet(UnarySetOp):
"""
Class for computing the absolute values of elements in a set.
"""

def __init__(self):
"""
Initializes the abs set operation.
"""
super(AbsSet, self).__init__()
self.op_name = "abs"

def do_process(self, A):
"""
Computes the absolute values of the elements in the input set.

:param A: The input set.
:return: A set of absolute values of the elements in the input set.
"""
return {abs(x) for x in A}
Loading
Loading