Skip to content

Commit 1b79f32

Browse files
haohangyanbgyori
authored andcommitted
Use source count dict to get the evidence count in INDRA net
1 parent ed26006 commit 1b79f32

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

indra/assemblers/indranet/assembler.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def _store_edge_data(stmts, extra_columns=None):
518518

519519

520520
def statement_to_rows(stmt, exclude_stmts=None, complex_members=3,
521-
extra_columns=None, keep_self_loops=True):
521+
extra_columns=None, keep_self_loops=True, source_counts=None):
522522
rows = []
523523
if exclude_stmts:
524524
exclude_types = tuple(
@@ -592,6 +592,13 @@ def statement_to_rows(stmt, exclude_stmts=None, complex_members=3,
592592
pos = None
593593
# Create a simple flat list of just the values instead
594594
# of a dict with keys
595+
statemet_hash = stmt.get_hash(refresh=True)
596+
if source_counts:
597+
evidence_count = sum(source_counts.get(statemet_hash, {}).values())
598+
source_count = source_counts.get(statemet_hash, {})
599+
else:
600+
evidence_count = len(stmt.evidence)
601+
source_count = _get_source_counts(stmt)
595602
row = [
596603
agA.name,
597604
agB.name,
@@ -602,10 +609,10 @@ def statement_to_rows(stmt, exclude_stmts=None, complex_members=3,
602609
res,
603610
pos,
604611
stmt_type,
605-
len(stmt.evidence),
606-
stmt.get_hash(refresh=True),
612+
evidence_count,
613+
statemet_hash,
607614
stmt.belief,
608-
_get_source_counts(stmt),
615+
source_count,
609616
sign
610617
]
611618
if extra_columns:

0 commit comments

Comments
 (0)