From ef6a737e5ef55279cf65272a2a3c2c740d13f0b3 Mon Sep 17 00:00:00 2001 From: kaustubh-darekar Date: Wed, 15 Jan 2025 06:14:04 +0000 Subject: [PATCH] rectified code to not include Document node while graph_consolidation --- backend/score.py | 17 +++++++++-------- backend/src/main.py | 2 +- backend/src/post_processing.py | 10 +++++++++- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/backend/score.py b/backend/score.py index 6869b1b8..e5a416fc 100644 --- a/backend/score.py +++ b/backend/score.py @@ -346,14 +346,15 @@ async def post_processing(uri=Form(), userName=Form(), password=Form(), database await asyncio.to_thread(create_communities, uri, userName, password, database) logging.info(f'created communities') - graph = create_graph_database_connection(uri, userName, password, database) - graphDb_data_Access = graphDBdataAccess(graph) - document_name = "" - count_response = graphDb_data_Access.update_node_relationship_count(document_name) - if count_response: - count_response = [{"filename": filename, **counts} for filename, counts in count_response.items()] - logging.info(f'Updated source node with community related counts') - + + + graph = create_graph_database_connection(uri, userName, password, database) + graphDb_data_Access = graphDBdataAccess(graph) + document_name = "" + count_response = graphDb_data_Access.update_node_relationship_count(document_name) + if count_response: + count_response = [{"filename": filename, **counts} for filename, counts in count_response.items()] + logging.info(f'Updated source node with community related counts') end = time.time() elapsed_time = end - start diff --git a/backend/src/main.py b/backend/src/main.py index 85274036..99d6407d 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -676,7 +676,7 @@ def get_labels_and_relationtypes(graph): query = """ RETURN collect { CALL db.labels() yield label - WHERE NOT label IN ['Chunk','_Bloom_Perspective_', '__Community__', '__Entity__'] + WHERE NOT label IN ['Document','Chunk','_Bloom_Perspective_', '__Community__', '__Entity__'] return label order by label limit 100 } as labels, collect { CALL db.relationshipTypes() yield relationshipType as type diff --git a/backend/src/post_processing.py b/backend/src/post_processing.py index 8b79f93b..cb7993a1 100644 --- a/backend/src/post_processing.py +++ b/backend/src/post_processing.py @@ -203,6 +203,12 @@ def graph_schema_consolidation(graph): node_labels.extend(nodes_and_relations[0]['labels']) relation_labels.extend(nodes_and_relations[0]['relationshipTypes']) + exclude_node_labels = ['Document','Chunk','_Bloom_Perspective_', '__Community__', '__Entity__'] + exclude_relationship_labels = ['PART_OF', 'NEXT_CHUNK', 'HAS_ENTITY', '_Bloom_Perspective_','FIRST_CHUNK','SIMILAR','IN_COMMUNITY','PARENT_COMMUNITY'] + + node_labels = [i for i in node_labels if i not in exclude_node_labels ] + relation_labels = [i for i in relation_labels if i not in exclude_relationship_labels] + parser = JsonOutputParser() prompt = ChatPromptTemplate(messages=[("system",GRAPH_CLEANUP_PROMPT),("human", "{input}")], partial_variables={"format_instructions": parser.get_format_instructions()}) @@ -225,8 +231,10 @@ def graph_schema_consolidation(graph): if new_label != old_label: relation_match[old_label]=new_label - logging.info(f"updated node labels : {node_match}") + logging.info(f"updated node labels : {node_match}") + logging.info(f"Reduced node counts from {len(node_labels)} to {len(node_match.items())}") logging.info(f"updated relationship labels : {relation_match}") + logging.info(f"Reduced relationship counts from {len(relation_labels)} to {len(relation_match.items())}") # Update node labels in graph for old_label, new_label in node_match.items():